diff --git a/backends/arm/operator_support/tosa_supported_operators.py b/backends/arm/operator_support/tosa_supported_operators.py index f76b30ce2a1..1cec35923de 100644 --- a/backends/arm/operator_support/tosa_supported_operators.py +++ b/backends/arm/operator_support/tosa_supported_operators.py @@ -112,6 +112,7 @@ def tosa_support_factory( # Negative checks: Remove nodes from partitioning negative_checks: list[OperatorSupportBase] = [ CheckInt64Inputs(exported_program, reporter), + CheckFloat64Inputs(exported_program, reporter), *[ reporter.wrap_check(check, f"Rejected by {check.__class__.__name__}") for check in (additional_checks if additional_checks else []) @@ -443,3 +444,26 @@ def is_node_supported( ) return False return True + + +class CheckFloat64Inputs(OperatorSupportBase): + + def __init__( + self, exported_program: ExportedProgram, reporter: WhyNoPartitionReporter + ): + self.reporter = reporter + super().__init__() + + def is_node_supported( + self, submodules: typing.Mapping[str, torch.nn.Module], node: fx.Node + ) -> bool: + + for input_node in node.all_input_nodes: + tensor = get_first_fake_tensor(input_node) + if tensor.dtype == torch.float64: + self.reporter.report_reject( + node, + f"Had float64 input {input_node.name} that couldn't be handled.", + ) + return False + return True