From e4f9965069e9c6939927833f398172cac3fc7feb Mon Sep 17 00:00:00 2001 From: Yufeng Shi Date: Tue, 8 Apr 2025 14:04:35 +0100 Subject: [PATCH] Arm backend: Add check to not partition ops with float64 input - Float64 placeholders are not supported in the Arm backend. They will cause a crash when processed in the process_placeholder function. This patch rejects Float64 placeholders early to prevent crashes during the partition. Change-Id: I7e7c61836a7a50f29c18252819cf5e537f3b595a Signed-off-by: Yufeng Shi --- .../tosa_supported_operators.py | 24 +++++++++++++++++++ 1 file changed, 24 insertions(+) diff --git a/backends/arm/operator_support/tosa_supported_operators.py b/backends/arm/operator_support/tosa_supported_operators.py index f76b30ce2a1..1cec35923de 100644 --- a/backends/arm/operator_support/tosa_supported_operators.py +++ b/backends/arm/operator_support/tosa_supported_operators.py @@ -112,6 +112,7 @@ def tosa_support_factory( # Negative checks: Remove nodes from partitioning negative_checks: list[OperatorSupportBase] = [ CheckInt64Inputs(exported_program, reporter), + CheckFloat64Inputs(exported_program, reporter), *[ reporter.wrap_check(check, f"Rejected by {check.__class__.__name__}") for check in (additional_checks if additional_checks else []) @@ -443,3 +444,26 @@ def is_node_supported( ) return False return True + + +class CheckFloat64Inputs(OperatorSupportBase): + + def __init__( + self, exported_program: ExportedProgram, reporter: WhyNoPartitionReporter + ): + self.reporter = reporter + super().__init__() + + def is_node_supported( + self, submodules: typing.Mapping[str, torch.nn.Module], node: fx.Node + ) -> bool: + + for input_node in node.all_input_nodes: + tensor = get_first_fake_tensor(input_node) + if tensor.dtype == torch.float64: + self.reporter.report_reject( + node, + f"Had float64 input {input_node.name} that couldn't be handled.", + ) + return False + return True