diff --git a/backends/arm/arm_partitioner.py b/backends/arm/arm_partitioner.py
index 8fde8dff610..b282139b7f2 100644
--- a/backends/arm/arm_partitioner.py
+++ b/backends/arm/arm_partitioner.py
@@ -7,14 +7,14 @@
 import logging
 import os
-from typing import Callable, final, List, Optional, Tuple
+from typing import Callable, final, List, Optional, Sequence, Tuple
 
 import torch
 from executorch.backends.arm.arm_backend import (  # type: ignore[attr-defined]
     ArmBackend,
 )  # usort: skip
 from executorch.backends.arm.operator_support.tosa_supported_operators import (
-    TOSASupportedOperators,
+    tosa_support_factory,
 )
 from executorch.backends.arm.tosa_specification import TosaSpecification
 from executorch.exir.backend.compile_spec_schema import CompileSpec
@@ -27,6 +27,8 @@
 from executorch.exir.dialects._ops import ops as exir_ops
 from torch.export.exported_program import ExportedProgram
 from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner
+from torch.fx.passes.operator_support import OperatorSupportBase
+
 
 logger = logging.getLogger(__name__)
 logger.setLevel(logging.WARNING)
@@ -54,8 +56,13 @@ def is_dequant_node(node: torch.fx.node.Node) -> bool:
 
 @final
 class ArmPartitioner(Partitioner):
-    def __init__(self, compile_spec: List[CompileSpec]) -> None:
+    def __init__(
+        self,
+        compile_spec: List[CompileSpec],
+        additional_checks: Optional[Sequence[OperatorSupportBase]] = None,
+    ) -> None:
         self.delegation_spec = DelegationSpec(ArmBackend.__name__, compile_spec)
+        self.additional_checks = additional_checks
 
     def partition(self, exported_program: ExportedProgram) -> PartitionResult:
         # Run the CapabilityBasedPartitioner to return the largest possible
@@ -72,7 +79,7 @@ def partition(self, exported_program: ExportedProgram) -> PartitionResult:
 
         capability_partitioner = CapabilityBasedPartitioner(
             exported_program.graph_module,
-            TOSASupportedOperators(tosa_spec),
+            tosa_support_factory(tosa_spec, self.additional_checks),
             allows_single_node_partition=True,
         )
         partition_list = capability_partitioner.propose_partitions()
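# Example (not part of the patch): a minimal sketch of how the new
# `additional_checks` argument is intended to be used. It relies on the
# DontPartition check and the common.get_tosa_compile_spec() helper that
# appear later in this diff.
from executorch.backends.arm.arm_partitioner import ArmPartitioner
from executorch.backends.arm.test import common
from executorch.exir.backend.operator_support import DontPartition
from executorch.exir.dialects._ops import ops as exir_ops

# Keep sigmoid nodes out of the Arm delegate; everything else that passes the
# built-in TOSA checks is still partitioned.
compile_spec = common.get_tosa_compile_spec("TOSA-0.80+MI")
check = DontPartition(exir_ops.edge.aten.sigmoid.default)
partitioner = ArmPartitioner(compile_spec, additional_checks=[check])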
diff --git a/backends/arm/operator_support/convolution_support.py b/backends/arm/operator_support/convolution_support.py
index ffa74942fa6..0d0a32200e8 100644
--- a/backends/arm/operator_support/convolution_support.py
+++ b/backends/arm/operator_support/convolution_support.py
@@ -24,7 +24,7 @@ class ConvolutionSupported(SupportedTOSAOperatorCheck):
         TosaSpecification.create_from_string("TOSA-0.80+MI"),
     ]
 
-    def is_node_supported(self, node: fx.Node, tosa_spec: TosaSpecification):
+    def is_node_tosa_supported(self, node: fx.Node, tosa_spec: TosaSpecification):
 
         # Not implemented
         transposed = cast(bool, node.args[6])
diff --git a/backends/arm/operator_support/pool_2d_support.py b/backends/arm/operator_support/pool_2d_support.py
index ae3c7120731..7aa35a721b6 100644
--- a/backends/arm/operator_support/pool_2d_support.py
+++ b/backends/arm/operator_support/pool_2d_support.py
@@ -43,7 +43,7 @@ class AvgPool2dSupported(SupportedTOSAOperatorCheck):
         TosaSpecification.create_from_string("TOSA-0.80+MI"),
     ]
 
-    def is_node_supported(self, node: fx.Node, tosa_spec: TosaSpecification):
+    def is_node_tosa_supported(self, node: fx.Node, tosa_spec: TosaSpecification):
         if not (isinstance(tosa_spec, Tosa_0_80) and tosa_spec.is_U55_subset):
             return True
 
@@ -73,7 +73,7 @@ class MaxPool2dSupported(SupportedTOSAOperatorCheck):
         TosaSpecification.create_from_string("TOSA-0.80+MI"),
     ]
 
-    def is_node_supported(self, node: fx.Node, tosa_spec: TosaSpecification):
+    def is_node_tosa_supported(self, node: fx.Node, tosa_spec: TosaSpecification):
         if not (isinstance(tosa_spec, Tosa_0_80) and tosa_spec.is_U55_subset):
             return True
diff --git a/backends/arm/operator_support/reduce_sum_support.py b/backends/arm/operator_support/reduce_sum_support.py
index 1a337be2da1..8345d69caaa 100644
--- a/backends/arm/operator_support/reduce_sum_support.py
+++ b/backends/arm/operator_support/reduce_sum_support.py
@@ -23,7 +23,7 @@ class SumSupported(SupportedTOSAOperatorCheck):
         TosaSpecification.create_from_string("TOSA-0.80+MI"),
     ]
 
-    def is_node_supported(self, node: fx.Node, tosa_spec: TosaSpecification):
+    def is_node_tosa_supported(self, node: fx.Node, tosa_spec: TosaSpecification):
         if not (isinstance(tosa_spec, Tosa_0_80) and tosa_spec.is_U55_subset):
             return True
diff --git a/backends/arm/operator_support/right_shift_support.py b/backends/arm/operator_support/right_shift_support.py
index c0280919693..44aa2a38af5 100644
--- a/backends/arm/operator_support/right_shift_support.py
+++ b/backends/arm/operator_support/right_shift_support.py
@@ -1,4 +1,4 @@
-# Copyright 2024 Arm Limited and/or its affiliates.
+# Copyright 2024-2025 Arm Limited and/or its affiliates.
 #
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
@@ -29,7 +29,7 @@ class RightShiftSupported(SupportedTOSAOperatorCheck):
         TosaSpecification.create_from_string("TOSA-0.80+MI"),
     ]
 
-    def is_node_supported(self, node: fx.Node, tosa_spec: TosaSpecification):
+    def is_node_tosa_supported(self, node: fx.Node, tosa_spec: TosaSpecification):
 
         # TODO MLETORCH-525 Remove warning
         if isinstance(tosa_spec, Tosa_0_80) and tosa_spec.is_U55_subset:
diff --git a/backends/arm/operator_support/to_copy_support.py b/backends/arm/operator_support/to_copy_support.py
index 27c5f24f2ee..c81c8e58a29 100644
--- a/backends/arm/operator_support/to_copy_support.py
+++ b/backends/arm/operator_support/to_copy_support.py
@@ -70,7 +70,9 @@ def _merge_supported_types(
     )
     POSSIBLE_TYPE_CONVERSIONS = {torch.int64: torch.int32}
 
-    def is_node_supported(self, node: fx.Node, tosa_spec: TosaSpecification) -> bool:
+    def is_node_tosa_supported(
+        self, node: fx.Node, tosa_spec: TosaSpecification
+    ) -> bool:
         assert node.target in self.targets
 
         if tosa_spec not in self.tosa_specs:
diff --git a/backends/arm/operator_support/tosa_supported_operators.py b/backends/arm/operator_support/tosa_supported_operators.py
index 237da6214e8..43ab9ea10b5 100644
--- a/backends/arm/operator_support/tosa_supported_operators.py
+++ b/backends/arm/operator_support/tosa_supported_operators.py
@@ -6,70 +6,86 @@
 # pyre-unsafe
 
 import operator
-from typing import Type
+from typing import final, Optional, Sequence, Type
 
 import torch.fx as fx
 
 from executorch.backends.arm.tosa_specification import TosaSpecification
 from executorch.exir.dialects._ops import ops as exir_ops
-from torch.fx.passes.operator_support import OperatorSupportBase
+from torch.fx.passes.operator_support import any_chain, chain, OperatorSupportBase
 
 
-class SupportedTOSAOperatorCheck:
+class SupportedTOSAOperatorCheck(OperatorSupportBase):
     """
    Supported OP for TOSA lowering
     """
 
+    def __init__(self, tosa_spec: TosaSpecification):
+        self.tosa_spec = tosa_spec
+
     # Should be populated by subclass implementation
     tosa_specs: list[TosaSpecification] = []
     targets: list[str] = []
 
-    def is_node_supported(self, node: fx.Node, tosa_spec: TosaSpecification) -> bool:
+    @final
+    def is_node_supported(self, submodules, node: fx.Node) -> bool:
+        if node.target not in self.targets:
+            return False
+        return self.is_node_tosa_supported(node, self.tosa_spec)
+
+    def is_node_tosa_supported(
+        self, node: fx.Node, tosa_spec: TosaSpecification
+    ) -> bool:
         """
         Checks if the fx.Node node is lowerable using the TOSA specification
         defined by tosa_spec.
-        To be implemented by subclasses targeting
         """
-        raise NotImplementedError("NodeVisitor must be extended.")
+        raise NotImplementedError("SupportedTOSAOperatorCheck must be extended.")
 
 
 # container for all SupportedTosaOperatorCheck classes
-_tosa_spec_dicts: dict[
-    TosaSpecification, dict[str, Type[SupportedTOSAOperatorCheck]]
-] = {
-    TosaSpecification.create_from_string("TOSA-0.80+BI"): {},
-    TosaSpecification.create_from_string("TOSA-0.80+MI"): {},
+_tosa_spec_support: dict[TosaSpecification, list[Type[SupportedTOSAOperatorCheck]]] = {
+    TosaSpecification.create_from_string("TOSA-0.80+BI"): [],
+    TosaSpecification.create_from_string("TOSA-0.80+MI"): [],
 }
 
 
-def register_tosa_support_check(checker):
+def register_tosa_support_check(checker: Type[SupportedTOSAOperatorCheck]):
     """
     Decorator to mark a subclass implmentation of SupportedTosaOperatorCheck
     to be registered for checking if a torch.fx.Node is lowerable given
     a TOSA specification.
     """
     for tosa_spec in checker.tosa_specs:
-        for target in checker.targets:
-            _tosa_spec_dicts[tosa_spec][target] = checker
+        _tosa_spec_support[tosa_spec].append(checker)
     return checker
 
 
 def get_registered_tosa_support_checks(
     tosa_spec: TosaSpecification,
-) -> dict[str, SupportedTOSAOperatorCheck]:
+) -> list[Type[SupportedTOSAOperatorCheck]]:
 
-    if tosa_spec not in _tosa_spec_dicts:
+    if tosa_spec not in _tosa_spec_support:
         raise RuntimeError
 
-    tosa_support_checks = {}
-    for target, tosa_check in _tosa_spec_dicts[tosa_spec].items():
-        tosa_support_checks[target] = tosa_check()
-
-    return tosa_support_checks
+    return _tosa_spec_support[tosa_spec]
 
 
-class TOSASupportedOperators(OperatorSupportBase):
-    def __init__(self, tosa_spec: TosaSpecification):
-        super().__init__()
-        self.tosa_spec = tosa_spec
+def tosa_support_factory(
+    tosa_spec: TosaSpecification,
+    additional_checks: Optional[Sequence[OperatorSupportBase]] = None,
+) -> OperatorSupportBase:
+    return chain(
+        any_chain(
+            BaseTOSASupportList(),
+            *(
+                check(tosa_spec)
+                for check in get_registered_tosa_support_checks(tosa_spec)
+            ),
+        ),
+        *additional_checks if additional_checks else [],
+    )
+
+
+class BaseTOSASupportList(OperatorSupportBase):
 
     def is_node_supported(self, submodules, node: fx.Node) -> bool:
         supported = node.op == "call_function" and node.target in [
@@ -123,18 +139,4 @@ def is_node_supported(self, submodules, node: fx.Node) -> bool:
             exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default,
         ]
 
-        if not supported:
-            supported = self.is_node_supported_custom(node)
-
-        # Override partitioning based on pre partition passes
-        if "arm_override_partition" in node.meta:
-            supported = supported & node.meta["arm_override_partition"]
-            node.meta.pop("arm_override_partition")
-
         return supported
-
-    def is_node_supported_custom(self, node: fx.Node) -> bool:
-        tosa_checks = get_registered_tosa_support_checks(self.tosa_spec)
-        if node.target in tosa_checks.keys():
-            return tosa_checks[node.target].is_node_supported(node, self.tosa_spec)  # type: ignore[index]
-        return False
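# Example (not part of the patch): a sketch of what an operator support check
# looks like under the reworked registration API. Target filtering and the
# tosa_spec argument are now handled by the SupportedTOSAOperatorCheck base
# class, so subclasses only implement is_node_tosa_supported(). The choice of
# aten.clone as the target is hypothetical.
import torch.fx as fx

from executorch.backends.arm.operator_support.tosa_supported_operators import (
    register_tosa_support_check,
    SupportedTOSAOperatorCheck,
)
from executorch.backends.arm.tosa_specification import TosaSpecification
from executorch.exir.dialects._ops import ops as exir_ops


@register_tosa_support_check
class CloneSupported(SupportedTOSAOperatorCheck):
    targets = [exir_ops.edge.aten.clone.default]

    tosa_specs = [
        TosaSpecification.create_from_string("TOSA-0.80+BI"),
        TosaSpecification.create_from_string("TOSA-0.80+MI"),
    ]

    def is_node_tosa_supported(self, node: fx.Node, tosa_spec: TosaSpecification):
        # Nodes with other targets were already rejected by the base class's
        # final is_node_supported(), so only TOSA-specific checks go here.
        return True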
diff --git a/backends/arm/test/misc/test_custom_partition.py b/backends/arm/test/misc/test_custom_partition.py
new file mode 100644
index 00000000000..fec15368351
--- /dev/null
+++ b/backends/arm/test/misc/test_custom_partition.py
@@ -0,0 +1,216 @@
+# Copyright 2025 Arm Limited and/or its affiliates.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+from executorch.backends.arm.arm_partitioner import ArmPartitioner
+from executorch.backends.arm.test import common
+from executorch.backends.arm.test.tester.arm_tester import ArmTester
+from executorch.exir.backend.operator_support import (
+    DontPartition,
+    DontPartitionModule,
+    DontPartitionName,
+)
+from executorch.exir.dialects._ops import ops as exir_ops
+
+
+class CustomPartitioning(torch.nn.Module):
+    inputs = (torch.randn(10, 4, 5), torch.randn(10, 4, 5))
+
+    def forward(self, x: torch.Tensor, y: torch.Tensor):
+        z = x + y
+        s = torch.sigmoid(z)
+        return s * z
+
+
+class NestedModule(torch.nn.Module):
+    inputs = (torch.randn(10, 4, 5), torch.randn(10, 4, 5))
+
+    def __init__(self):
+        super().__init__()
+        self.nested = CustomPartitioning()
+
+    def forward(self, x: torch.Tensor, y: torch.Tensor):
+        a = x.sigmoid()
+        b = a + y
+        return self.nested(a, b)
+
+
+def test_single_reject():
+    module = CustomPartitioning()
+    inputs = module.inputs
+    compile_spec = common.get_tosa_compile_spec("TOSA-0.80+MI")
+    check = DontPartition(exir_ops.edge.aten.sigmoid.default)
+    partitioner = ArmPartitioner(compile_spec, additional_checks=[check])
+    (
+        ArmTester(
+            module,
+            example_inputs=inputs,
+            compile_spec=compile_spec,
+        )
+        .export()
+        .to_edge_transform_and_lower(partitioners=[partitioner])
+        .check(["executorch_exir_dialects_edge__ops_aten_sigmoid_default"])
+        .check_count({"torch.ops.higher_order.executorch_call_delegate": 2})
+        .to_executorch()
+        .run_method_and_compare_outputs(inputs=inputs)
+    )
+    assert check.has_rejected_node()
+
+
+def test_multiple_reject():
+    module = CustomPartitioning()
+    inputs = module.inputs
+    compile_spec = common.get_tosa_compile_spec("TOSA-0.80+MI")
+    check = DontPartition(
+        exir_ops.edge.aten.sigmoid.default, exir_ops.edge.aten.mul.Tensor
+    )
+    partitioner = ArmPartitioner(compile_spec, additional_checks=[check])
+    (
+        ArmTester(
+            module,
+            example_inputs=inputs,
+            compile_spec=compile_spec,
+        )
+        .export()
+        .to_edge_transform_and_lower(partitioners=[partitioner])
+        .check(["executorch_exir_dialects_edge__ops_aten_sigmoid_default"])
+        .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
+        .to_executorch()
+        .run_method_and_compare_outputs(inputs=inputs)
+    )
+    assert check.has_rejected_node()
+
+
+def test_torch_op_reject():
+    module = CustomPartitioning()
+    inputs = module.inputs
+    compile_spec = common.get_tosa_compile_spec("TOSA-0.80+MI")
+    check = DontPartition(torch.ops.aten.sigmoid.default)
+    partitioner = ArmPartitioner(compile_spec, additional_checks=[check])
+    (
+        ArmTester(
+            module,
+            example_inputs=inputs,
+            compile_spec=compile_spec,
+        )
+        .export()
+        .to_edge_transform_and_lower(partitioners=[partitioner])
+        .check(["executorch_exir_dialects_edge__ops_aten_sigmoid_default"])
+        .check_count({"torch.ops.higher_order.executorch_call_delegate": 2})
+        .to_executorch()
+        .run_method_and_compare_outputs(inputs=inputs)
+    )
+    assert check.has_rejected_node()
+
+
+def test_string_op_reject():
+    module = CustomPartitioning()
+    inputs = module.inputs
+    compile_spec = common.get_tosa_compile_spec("TOSA-0.80+MI")
+    check = DontPartition("aten.sigmoid.default")
+    partitioner = ArmPartitioner(compile_spec, additional_checks=[check])
+    (
+        ArmTester(
+            module,
+            example_inputs=inputs,
+            compile_spec=compile_spec,
+        )
+        .export()
+        .to_edge_transform_and_lower(partitioners=[partitioner])
+        .check(["executorch_exir_dialects_edge__ops_aten_sigmoid_default"])
+        .check_count({"torch.ops.higher_order.executorch_call_delegate": 2})
+        .to_executorch()
+        .run_method_and_compare_outputs(inputs=inputs)
+    )
+
+    assert check.has_rejected_node()
+
+
+def test_name_reject():
+    module = CustomPartitioning()
+    inputs = module.inputs
+    compile_spec = common.get_tosa_compile_spec("TOSA-0.80+MI")
+    check = DontPartitionName("mul", "sigmoid", exact=False)
+    partitioner = ArmPartitioner(compile_spec, additional_checks=[check])
+    (
+        ArmTester(
+            module,
+            example_inputs=inputs,
+            compile_spec=compile_spec,
+        )
+        .export()
+        .to_edge_transform_and_lower(partitioners=[partitioner])
+        .check(["executorch_exir_dialects_edge__ops_aten_sigmoid_default"])
+        .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
+        .to_executorch()
+        .run_method_and_compare_outputs(inputs=inputs)
+    )
+    assert check.has_rejected_node()
+
+
+def test_module_reject():
+    module = NestedModule()
+    inputs = module.inputs
+    compile_spec = common.get_tosa_compile_spec("TOSA-0.80+MI")
+    check = DontPartitionModule(module_name="CustomPartitioning")
+    partitioner = ArmPartitioner(compile_spec, additional_checks=[check])
+    (
+        ArmTester(
+            module,
+            example_inputs=inputs,
+            compile_spec=compile_spec,
+        )
+        .export()
+        .to_edge_transform_and_lower(partitioners=[partitioner])
+        .check(["executorch_exir_dialects_edge__ops_aten_sigmoid_default"])
+        .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
+        .to_executorch()
+        .run_method_and_compare_outputs(inputs=inputs)
+    )
+    assert check.has_rejected_node()
+
+
+def test_inexact_module_reject():
+    module = NestedModule()
+    inputs = module.inputs
+    compile_spec = common.get_tosa_compile_spec("TOSA-0.80+MI")
+    check = DontPartitionModule(module_name="Custom", exact=False)
+    partitioner = ArmPartitioner(compile_spec, additional_checks=[check])
+    (
+        ArmTester(
+            module,
+            example_inputs=inputs,
+            compile_spec=compile_spec,
+        )
+        .export()
+        .to_edge_transform_and_lower(partitioners=[partitioner])
+        .check(["executorch_exir_dialects_edge__ops_aten_sigmoid_default"])
+        .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
+        .to_executorch()
+        .run_method_and_compare_outputs(inputs=inputs)
+    )
+    assert check.has_rejected_node()
+
+
+def test_module_instance_reject():
+    module = NestedModule()
+    inputs = module.inputs
+    compile_spec = common.get_tosa_compile_spec("TOSA-0.80+MI")
+    check = DontPartitionModule(instance_name="nested")
+    partitioner = ArmPartitioner(compile_spec, additional_checks=[check])
+    (
+        ArmTester(
+            module,
+            example_inputs=inputs,
+            compile_spec=compile_spec,
+        )
+        .export()
+        .to_edge_transform_and_lower(partitioners=[partitioner])
+        .check(["executorch_exir_dialects_edge__ops_aten_sigmoid_default"])
+        .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
+        .to_executorch()
+        .run_method_and_compare_outputs(inputs=inputs)
+    )
+    assert check.has_rejected_node()
diff --git a/backends/arm/test/tester/arm_tester.py b/backends/arm/test/tester/arm_tester.py
index 11e7d863043..4c74cb656c8 100644
--- a/backends/arm/test/tester/arm_tester.py
+++ b/backends/arm/test/tester/arm_tester.py
@@ -228,7 +228,7 @@ class ArmTester(Tester):
     def __init__(
         self,
         model: torch.nn.Module,
-        example_inputs: Tuple[torch.Tensor],
+        example_inputs: Tuple,
         compile_spec: List[CompileSpec],
     ):
         """
diff --git a/exir/backend/operator_support.py b/exir/backend/operator_support.py
new file mode 100644
index 00000000000..5dc0baba756
--- /dev/null
+++ b/exir/backend/operator_support.py
@@ -0,0 +1,127 @@
+# Copyright 2025 Arm Limited and/or its affiliates.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+from torch import fx
+from torch.fx.passes.operator_support import OperatorSupportBase
+
+
+def _compare(exact: bool, search_for: str | None, search_in: str) -> bool:
+    """Check whether the search_for str matches the search_in str.
+    Match can mean "identical" or "part of" depending on the `exact` flag.
+    """
+    if not search_for:
+        return False
+    if exact:
+        return search_for == search_in
+    else:
+        return search_for in search_in
+
+
+class DontSupportBase(OperatorSupportBase):
+    _rejected_nodes: list[fx.Node] = []
+
+    def reject_node(self, node: fx.Node):
+        self._rejected_nodes.append(node)
+
+    def rejected_nodes(self):
+        return self._rejected_nodes
+
+    def has_rejected_node(self) -> bool:
+        return self.num_rejected() > 0
+
+    def num_rejected(self) -> int:
+        return len(self._rejected_nodes)
+
+
+class DontPartition(DontSupportBase):
+    """Operator check to skip partitioning ops based on their target.
+    The target can be an EdgeOverloadOp (exir_ops.edge.aten.*),
+    OverloadOp (torch.ops.aten.*), or a string ("aten.*").
+
+    For the string case, set `exact` to False to match only part of the name.
+    """
+
+    def __init__(self, *targets, exact: bool = True):
+        self.targets = targets
+        self.exact = exact
+
+    def is_node_supported(self, submodules, node: fx.Node) -> bool:
+        if node.target in self.targets:
+            self.reject_node(node)
+            return False
+
+        if "original_aten" not in node.meta:
+            return True
+        stringified_node_target = str(node.meta["original_aten"])
+        for target in self.targets:
+            if _compare(self.exact, str(target), stringified_node_target):
+                self.reject_node(node)
+                return False
+        return True
+
+
+class DontPartitionName(DontSupportBase):
+    """Operator check to skip partitioning ops based on their name, which can be found
+    by for example node.name or print-outs of a GraphModule.
+
+    Set `exact` to False to match only part of the name.
+    """
+
+    def __init__(self, *targets, exact: bool = True):
+        self.targets = targets
+        self.exact = exact
+
+    def is_node_supported(self, submodules, node: fx.Node) -> bool:
+        for target in self.targets:
+            if _compare(self.exact, target, node.name):
+                self.reject_node(node)
+                return False
+        return True
+
+
+class DontPartitionModule(DontSupportBase):
+    """Operator check to skip partitioning modules.
+    You can pass either the module name, i.e. the class name of the module,
+    or the name of the instance that you want to skip.
+    If module_name contains a dot, the full module name of checked nodes is used,
+    if it does not, only part after the last dot is used.
+
+    For example, you could have two files defining MyClass, which have the full module name:
+    my_file.MyClass
+    my_other_file.MyClass
+    If you would call DontPartitionModule with module_name="MyClass", you would skip partitioning both.
+    With "my_file.MyClass", you would only target the first class.
+
+    Set `exact` to False to match only part of the name.
+    """
+
+    def __init__(
+        self,
+        *,
+        module_name: str | None = None,
+        instance_name: str | None = None,
+        exact: bool = True,
+    ):
+        self.module_name = module_name
+        self.instance_name = instance_name
+        self.exact = exact
+        self.used_dotted = "." in module_name if module_name else True
+
+    def is_node_supported(self, submodules, node: fx.Node) -> bool:
+        if "nn_module_stack" not in node.meta:
+            return True
+
+        for module_meta in node.meta["nn_module_stack"].values():
+            if _compare(self.exact, self.instance_name, module_meta[0]):
+                self.reject_node(node)
+                return False
+            node_module_name = module_meta[1]
+            if not self.used_dotted:
+                node_module_name = node_module_name.split(".")[-1]
+            if _compare(self.exact, self.module_name, node_module_name):
+                self.reject_node(node)
+                return False
+
+        return True
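# Example (not part of the patch): a sketch of how the rejection bookkeeping on
# DontSupportBase can be inspected after lowering. The Wrapper module and the
# `head` instance name are hypothetical; the export/lowering step is elided and
# would follow the ArmTester flow used in the tests above.
import torch

from executorch.backends.arm.arm_partitioner import ArmPartitioner
from executorch.backends.arm.test import common
from executorch.exir.backend.operator_support import DontPartitionModule


class Wrapper(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.head = torch.nn.Linear(5, 5)

    def forward(self, x):
        return self.head(torch.sigmoid(x))


# Keep everything under the `head` submodule instance out of the Arm delegate.
check = DontPartitionModule(instance_name="head")
compile_spec = common.get_tosa_compile_spec("TOSA-0.80+MI")
partitioner = ArmPartitioner(compile_spec, additional_checks=[check])

# ... export Wrapper() and lower with `partitioner` ...

# After partitioning, the check records every node it kept out of the delegate.
if check.has_rejected_node():
    print(f"Rejected {check.num_rejected()} nodes:", [n.name for n in check.rejected_nodes()])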