From 7b14f263db5315fdd7a978bcd7dce23d0dc6ca57 Mon Sep 17 00:00:00 2001 From: Erik Lundell Date: Thu, 30 Jan 2025 15:17:28 +0100 Subject: [PATCH 1/2] Refactor TOSA support to use chain flow This seems to match the intention of OperatorSupportBase and increases the flexibility of our support flow. New checks if different kinds are simply added to `tosa_support_factory` Signed-off-by: Erik Lundell Change-Id: I4f5f1669959f35c5e2770da396a352e1730cb3e0 --- backends/arm/arm_partitioner.py | 4 +- .../operator_support/convolution_support.py | 2 +- .../arm/operator_support/pool_2d_support.py | 4 +- .../operator_support/reduce_sum_support.py | 2 +- .../operator_support/right_shift_support.py | 4 +- .../arm/operator_support/to_copy_support.py | 4 +- .../tosa_supported_operators.py | 69 +++++++++---------- 7 files changed, 42 insertions(+), 47 deletions(-) diff --git a/backends/arm/arm_partitioner.py b/backends/arm/arm_partitioner.py index 8fde8dff610..a9bc92233f0 100644 --- a/backends/arm/arm_partitioner.py +++ b/backends/arm/arm_partitioner.py @@ -14,7 +14,7 @@ ArmBackend, ) # usort: skip from executorch.backends.arm.operator_support.tosa_supported_operators import ( - TOSASupportedOperators, + tosa_support_factory, ) from executorch.backends.arm.tosa_specification import TosaSpecification from executorch.exir.backend.compile_spec_schema import CompileSpec @@ -72,7 +72,7 @@ def partition(self, exported_program: ExportedProgram) -> PartitionResult: capability_partitioner = CapabilityBasedPartitioner( exported_program.graph_module, - TOSASupportedOperators(tosa_spec), + tosa_support_factory(tosa_spec), allows_single_node_partition=True, ) partition_list = capability_partitioner.propose_partitions() diff --git a/backends/arm/operator_support/convolution_support.py b/backends/arm/operator_support/convolution_support.py index ffa74942fa6..0d0a32200e8 100644 --- a/backends/arm/operator_support/convolution_support.py +++ b/backends/arm/operator_support/convolution_support.py @@ -24,7 +24,7 @@ class ConvolutionSupported(SupportedTOSAOperatorCheck): TosaSpecification.create_from_string("TOSA-0.80+MI"), ] - def is_node_supported(self, node: fx.Node, tosa_spec: TosaSpecification): + def is_node_tosa_supported(self, node: fx.Node, tosa_spec: TosaSpecification): # Not implemented transposed = cast(bool, node.args[6]) diff --git a/backends/arm/operator_support/pool_2d_support.py b/backends/arm/operator_support/pool_2d_support.py index ae3c7120731..7aa35a721b6 100644 --- a/backends/arm/operator_support/pool_2d_support.py +++ b/backends/arm/operator_support/pool_2d_support.py @@ -43,7 +43,7 @@ class AvgPool2dSupported(SupportedTOSAOperatorCheck): TosaSpecification.create_from_string("TOSA-0.80+MI"), ] - def is_node_supported(self, node: fx.Node, tosa_spec: TosaSpecification): + def is_node_tosa_supported(self, node: fx.Node, tosa_spec: TosaSpecification): if not (isinstance(tosa_spec, Tosa_0_80) and tosa_spec.is_U55_subset): return True @@ -73,7 +73,7 @@ class MaxPool2dSupported(SupportedTOSAOperatorCheck): TosaSpecification.create_from_string("TOSA-0.80+MI"), ] - def is_node_supported(self, node: fx.Node, tosa_spec: TosaSpecification): + def is_node_tosa_supported(self, node: fx.Node, tosa_spec: TosaSpecification): if not (isinstance(tosa_spec, Tosa_0_80) and tosa_spec.is_U55_subset): return True diff --git a/backends/arm/operator_support/reduce_sum_support.py b/backends/arm/operator_support/reduce_sum_support.py index 1a337be2da1..8345d69caaa 100644 --- a/backends/arm/operator_support/reduce_sum_support.py +++ b/backends/arm/operator_support/reduce_sum_support.py @@ -23,7 +23,7 @@ class SumSupported(SupportedTOSAOperatorCheck): TosaSpecification.create_from_string("TOSA-0.80+MI"), ] - def is_node_supported(self, node: fx.Node, tosa_spec: TosaSpecification): + def is_node_tosa_supported(self, node: fx.Node, tosa_spec: TosaSpecification): if not (isinstance(tosa_spec, Tosa_0_80) and tosa_spec.is_U55_subset): return True diff --git a/backends/arm/operator_support/right_shift_support.py b/backends/arm/operator_support/right_shift_support.py index c0280919693..44aa2a38af5 100644 --- a/backends/arm/operator_support/right_shift_support.py +++ b/backends/arm/operator_support/right_shift_support.py @@ -1,4 +1,4 @@ -# Copyright 2024 Arm Limited and/or its affiliates. +# Copyright 2024-2025 Arm Limited and/or its affiliates. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. @@ -29,7 +29,7 @@ class RightShiftSupported(SupportedTOSAOperatorCheck): TosaSpecification.create_from_string("TOSA-0.80+MI"), ] - def is_node_supported(self, node: fx.Node, tosa_spec: TosaSpecification): + def is_node_tosa_supported(self, node: fx.Node, tosa_spec: TosaSpecification): # TODO MLETORCH-525 Remove warning if isinstance(tosa_spec, Tosa_0_80) and tosa_spec.is_U55_subset: diff --git a/backends/arm/operator_support/to_copy_support.py b/backends/arm/operator_support/to_copy_support.py index 27c5f24f2ee..c81c8e58a29 100644 --- a/backends/arm/operator_support/to_copy_support.py +++ b/backends/arm/operator_support/to_copy_support.py @@ -70,7 +70,9 @@ def _merge_supported_types( ) POSSIBLE_TYPE_CONVERSIONS = {torch.int64: torch.int32} - def is_node_supported(self, node: fx.Node, tosa_spec: TosaSpecification) -> bool: + def is_node_tosa_supported( + self, node: fx.Node, tosa_spec: TosaSpecification + ) -> bool: assert node.target in self.targets if tosa_spec not in self.tosa_specs: diff --git a/backends/arm/operator_support/tosa_supported_operators.py b/backends/arm/operator_support/tosa_supported_operators.py index 237da6214e8..8ad1f90a79e 100644 --- a/backends/arm/operator_support/tosa_supported_operators.py +++ b/backends/arm/operator_support/tosa_supported_operators.py @@ -6,70 +6,77 @@ # pyre-unsafe import operator -from typing import Type +from typing import final, Type import torch.fx as fx from executorch.backends.arm.tosa_specification import TosaSpecification from executorch.exir.dialects._ops import ops as exir_ops -from torch.fx.passes.operator_support import OperatorSupportBase +from torch.fx.passes.operator_support import any_chain, OperatorSupportBase -class SupportedTOSAOperatorCheck: +class SupportedTOSAOperatorCheck(OperatorSupportBase): """ Supported OP for TOSA lowering """ + def __init__(self, tosa_spec: TosaSpecification): + self.tosa_spec = tosa_spec + # Should be populated by subclass implementation tosa_specs: list[TosaSpecification] = [] targets: list[str] = [] - def is_node_supported(self, node: fx.Node, tosa_spec: TosaSpecification) -> bool: + @final + def is_node_supported(self, submodules, node: fx.Node) -> bool: + if node.target not in self.targets: + return False + return self.is_node_tosa_supported(node, self.tosa_spec) + + def is_node_tosa_supported( + self, node: fx.Node, tosa_spec: TosaSpecification + ) -> bool: """ Checks if the fx.Node node is lowerable using the TOSA specification defined by tosa_spec. - To be implemented by subclasses targeting """ - raise NotImplementedError("NodeVisitor must be extended.") + raise NotImplementedError("SupportedTOSAOperatorCheck must be extended.") # container for all SupportedTosaOperatorCheck classes -_tosa_spec_dicts: dict[ - TosaSpecification, dict[str, Type[SupportedTOSAOperatorCheck]] -] = { - TosaSpecification.create_from_string("TOSA-0.80+BI"): {}, - TosaSpecification.create_from_string("TOSA-0.80+MI"): {}, +_tosa_spec_support: dict[TosaSpecification, list[Type[SupportedTOSAOperatorCheck]]] = { + TosaSpecification.create_from_string("TOSA-0.80+BI"): [], + TosaSpecification.create_from_string("TOSA-0.80+MI"): [], } -def register_tosa_support_check(checker): +def register_tosa_support_check(checker: Type[SupportedTOSAOperatorCheck]): """ Decorator to mark a subclass implmentation of SupportedTosaOperatorCheck to be registered for checking if a torch.fx.Node is lowerable given a TOSA specification. """ for tosa_spec in checker.tosa_specs: - for target in checker.targets: - _tosa_spec_dicts[tosa_spec][target] = checker + _tosa_spec_support[tosa_spec].append(checker) return checker def get_registered_tosa_support_checks( tosa_spec: TosaSpecification, -) -> dict[str, SupportedTOSAOperatorCheck]: +) -> list[Type[SupportedTOSAOperatorCheck]]: - if tosa_spec not in _tosa_spec_dicts: + if tosa_spec not in _tosa_spec_support: raise RuntimeError - tosa_support_checks = {} - for target, tosa_check in _tosa_spec_dicts[tosa_spec].items(): - tosa_support_checks[target] = tosa_check() + return _tosa_spec_support[tosa_spec] - return tosa_support_checks +def tosa_support_factory(tosa_spec: TosaSpecification) -> OperatorSupportBase: + return any_chain( + BaseTOSASupportList(), + *(check(tosa_spec) for check in get_registered_tosa_support_checks(tosa_spec)), + ) -class TOSASupportedOperators(OperatorSupportBase): - def __init__(self, tosa_spec: TosaSpecification): - super().__init__() - self.tosa_spec = tosa_spec + +class BaseTOSASupportList(OperatorSupportBase): def is_node_supported(self, submodules, node: fx.Node) -> bool: supported = node.op == "call_function" and node.target in [ @@ -123,18 +130,4 @@ def is_node_supported(self, submodules, node: fx.Node) -> bool: exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default, ] - if not supported: - supported = self.is_node_supported_custom(node) - - # Override partitioning based on pre partition passes - if "arm_override_partition" in node.meta: - supported = supported & node.meta["arm_override_partition"] - node.meta.pop("arm_override_partition") - return supported - - def is_node_supported_custom(self, node: fx.Node) -> bool: - tosa_checks = get_registered_tosa_support_checks(self.tosa_spec) - if node.target in tosa_checks.keys(): - return tosa_checks[node.target].is_node_supported(node, self.tosa_spec) # type: ignore[index] - return False From 2584dc4724d6b92aa0cfa4f1e7201fe0564722a9 Mon Sep 17 00:00:00 2001 From: Erik Lundell Date: Thu, 30 Jan 2025 16:33:36 +0100 Subject: [PATCH 2/2] Add additional checks for operator support. This can be used to avoid partitioning parts of a model when debugging. Though any OperatorSupportBase can be used, we add three OperatorSupport as utilities: DontPartition: Don't partition based on node target DontPartitionName: Don't partition based on node name DontPartitionModule: Don't partition based on which module the op comes from. All these checks can match parts of the target name, and save a list of the nodes they reject for debugging. Signed-off-by: Erik Lundell Change-Id: I0b2537370da4aadcffbb87c52cb98e82d78cf27f --- backends/arm/arm_partitioner.py | 13 +- .../tosa_supported_operators.py | 21 +- .../arm/test/misc/test_custom_partition.py | 216 ++++++++++++++++++ backends/arm/test/tester/arm_tester.py | 2 +- exir/backend/operator_support.py | 127 ++++++++++ 5 files changed, 369 insertions(+), 10 deletions(-) create mode 100644 backends/arm/test/misc/test_custom_partition.py create mode 100644 exir/backend/operator_support.py diff --git a/backends/arm/arm_partitioner.py b/backends/arm/arm_partitioner.py index a9bc92233f0..b282139b7f2 100644 --- a/backends/arm/arm_partitioner.py +++ b/backends/arm/arm_partitioner.py @@ -7,7 +7,7 @@ import logging import os -from typing import Callable, final, List, Optional, Tuple +from typing import Callable, final, List, Optional, Sequence, Tuple import torch from executorch.backends.arm.arm_backend import ( # type: ignore[attr-defined] @@ -27,6 +27,8 @@ from executorch.exir.dialects._ops import ops as exir_ops from torch.export.exported_program import ExportedProgram from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner +from torch.fx.passes.operator_support import OperatorSupportBase + logger = logging.getLogger(__name__) logger.setLevel(logging.WARNING) @@ -54,8 +56,13 @@ def is_dequant_node(node: torch.fx.node.Node) -> bool: @final class ArmPartitioner(Partitioner): - def __init__(self, compile_spec: List[CompileSpec]) -> None: + def __init__( + self, + compile_spec: List[CompileSpec], + additional_checks: Optional[Sequence[OperatorSupportBase]] = None, + ) -> None: self.delegation_spec = DelegationSpec(ArmBackend.__name__, compile_spec) + self.additional_checks = additional_checks def partition(self, exported_program: ExportedProgram) -> PartitionResult: # Run the CapabilityBasedPartitioner to return the largest possible @@ -72,7 +79,7 @@ def partition(self, exported_program: ExportedProgram) -> PartitionResult: capability_partitioner = CapabilityBasedPartitioner( exported_program.graph_module, - tosa_support_factory(tosa_spec), + tosa_support_factory(tosa_spec, self.additional_checks), allows_single_node_partition=True, ) partition_list = capability_partitioner.propose_partitions() diff --git a/backends/arm/operator_support/tosa_supported_operators.py b/backends/arm/operator_support/tosa_supported_operators.py index 8ad1f90a79e..43ab9ea10b5 100644 --- a/backends/arm/operator_support/tosa_supported_operators.py +++ b/backends/arm/operator_support/tosa_supported_operators.py @@ -6,12 +6,12 @@ # pyre-unsafe import operator -from typing import final, Type +from typing import final, Optional, Sequence, Type import torch.fx as fx from executorch.backends.arm.tosa_specification import TosaSpecification from executorch.exir.dialects._ops import ops as exir_ops -from torch.fx.passes.operator_support import any_chain, OperatorSupportBase +from torch.fx.passes.operator_support import any_chain, chain, OperatorSupportBase class SupportedTOSAOperatorCheck(OperatorSupportBase): @@ -69,10 +69,19 @@ def get_registered_tosa_support_checks( return _tosa_spec_support[tosa_spec] -def tosa_support_factory(tosa_spec: TosaSpecification) -> OperatorSupportBase: - return any_chain( - BaseTOSASupportList(), - *(check(tosa_spec) for check in get_registered_tosa_support_checks(tosa_spec)), +def tosa_support_factory( + tosa_spec: TosaSpecification, + additional_checks: Optional[Sequence[OperatorSupportBase]] = None, +) -> OperatorSupportBase: + return chain( + any_chain( + BaseTOSASupportList(), + *( + check(tosa_spec) + for check in get_registered_tosa_support_checks(tosa_spec) + ), + ), + *additional_checks if additional_checks else [], ) diff --git a/backends/arm/test/misc/test_custom_partition.py b/backends/arm/test/misc/test_custom_partition.py new file mode 100644 index 00000000000..fec15368351 --- /dev/null +++ b/backends/arm/test/misc/test_custom_partition.py @@ -0,0 +1,216 @@ +# Copyright 2025 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import torch +from executorch.backends.arm.arm_partitioner import ArmPartitioner +from executorch.backends.arm.test import common +from executorch.backends.arm.test.tester.arm_tester import ArmTester +from executorch.exir.backend.operator_support import ( + DontPartition, + DontPartitionModule, + DontPartitionName, +) +from executorch.exir.dialects._ops import ops as exir_ops + + +class CustomPartitioning(torch.nn.Module): + inputs = (torch.randn(10, 4, 5), torch.randn(10, 4, 5)) + + def forward(self, x: torch.Tensor, y: torch.Tensor): + z = x + y + s = torch.sigmoid(z) + return s * z + + +class NestedModule(torch.nn.Module): + inputs = (torch.randn(10, 4, 5), torch.randn(10, 4, 5)) + + def __init__(self): + super().__init__() + self.nested = CustomPartitioning() + + def forward(self, x: torch.Tensor, y: torch.Tensor): + a = x.sigmoid() + b = a + y + return self.nested(a, b) + + +def test_single_reject(): + module = CustomPartitioning() + inputs = module.inputs + compile_spec = common.get_tosa_compile_spec("TOSA-0.80+MI") + check = DontPartition(exir_ops.edge.aten.sigmoid.default) + partitioner = ArmPartitioner(compile_spec, additional_checks=[check]) + ( + ArmTester( + module, + example_inputs=inputs, + compile_spec=compile_spec, + ) + .export() + .to_edge_transform_and_lower(partitioners=[partitioner]) + .check(["executorch_exir_dialects_edge__ops_aten_sigmoid_default"]) + .check_count({"torch.ops.higher_order.executorch_call_delegate": 2}) + .to_executorch() + .run_method_and_compare_outputs(inputs=inputs) + ) + assert check.has_rejected_node() + + +def test_multiple_reject(): + module = CustomPartitioning() + inputs = module.inputs + compile_spec = common.get_tosa_compile_spec("TOSA-0.80+MI") + check = DontPartition( + exir_ops.edge.aten.sigmoid.default, exir_ops.edge.aten.mul.Tensor + ) + partitioner = ArmPartitioner(compile_spec, additional_checks=[check]) + ( + ArmTester( + module, + example_inputs=inputs, + compile_spec=compile_spec, + ) + .export() + .to_edge_transform_and_lower(partitioners=[partitioner]) + .check(["executorch_exir_dialects_edge__ops_aten_sigmoid_default"]) + .check_count({"torch.ops.higher_order.executorch_call_delegate": 1}) + .to_executorch() + .run_method_and_compare_outputs(inputs=inputs) + ) + assert check.has_rejected_node() + + +def test_torch_op_reject(): + module = CustomPartitioning() + inputs = module.inputs + compile_spec = common.get_tosa_compile_spec("TOSA-0.80+MI") + check = DontPartition(torch.ops.aten.sigmoid.default) + partitioner = ArmPartitioner(compile_spec, additional_checks=[check]) + ( + ArmTester( + module, + example_inputs=inputs, + compile_spec=compile_spec, + ) + .export() + .to_edge_transform_and_lower(partitioners=[partitioner]) + .check(["executorch_exir_dialects_edge__ops_aten_sigmoid_default"]) + .check_count({"torch.ops.higher_order.executorch_call_delegate": 2}) + .to_executorch() + .run_method_and_compare_outputs(inputs=inputs) + ) + assert check.has_rejected_node() + + +def test_string_op_reject(): + module = CustomPartitioning() + inputs = module.inputs + compile_spec = common.get_tosa_compile_spec("TOSA-0.80+MI") + check = DontPartition("aten.sigmoid.default") + partitioner = ArmPartitioner(compile_spec, additional_checks=[check]) + ( + ArmTester( + module, + example_inputs=inputs, + compile_spec=compile_spec, + ) + .export() + .to_edge_transform_and_lower(partitioners=[partitioner]) + .check(["executorch_exir_dialects_edge__ops_aten_sigmoid_default"]) + .check_count({"torch.ops.higher_order.executorch_call_delegate": 2}) + .to_executorch() + .run_method_and_compare_outputs(inputs=inputs) + ) + + assert check.has_rejected_node() + + +def test_name_reject(): + module = CustomPartitioning() + inputs = module.inputs + compile_spec = common.get_tosa_compile_spec("TOSA-0.80+MI") + check = DontPartitionName("mul", "sigmoid", exact=False) + partitioner = ArmPartitioner(compile_spec, additional_checks=[check]) + ( + ArmTester( + module, + example_inputs=inputs, + compile_spec=compile_spec, + ) + .export() + .to_edge_transform_and_lower(partitioners=[partitioner]) + .check(["executorch_exir_dialects_edge__ops_aten_sigmoid_default"]) + .check_count({"torch.ops.higher_order.executorch_call_delegate": 1}) + .to_executorch() + .run_method_and_compare_outputs(inputs=inputs) + ) + assert check.has_rejected_node() + + +def test_module_reject(): + module = NestedModule() + inputs = module.inputs + compile_spec = common.get_tosa_compile_spec("TOSA-0.80+MI") + check = DontPartitionModule(module_name="CustomPartitioning") + partitioner = ArmPartitioner(compile_spec, additional_checks=[check]) + ( + ArmTester( + module, + example_inputs=inputs, + compile_spec=compile_spec, + ) + .export() + .to_edge_transform_and_lower(partitioners=[partitioner]) + .check(["executorch_exir_dialects_edge__ops_aten_sigmoid_default"]) + .check_count({"torch.ops.higher_order.executorch_call_delegate": 1}) + .to_executorch() + .run_method_and_compare_outputs(inputs=inputs) + ) + assert check.has_rejected_node() + + +def test_inexact_module_reject(): + module = NestedModule() + inputs = module.inputs + compile_spec = common.get_tosa_compile_spec("TOSA-0.80+MI") + check = DontPartitionModule(module_name="Custom", exact=False) + partitioner = ArmPartitioner(compile_spec, additional_checks=[check]) + ( + ArmTester( + module, + example_inputs=inputs, + compile_spec=compile_spec, + ) + .export() + .to_edge_transform_and_lower(partitioners=[partitioner]) + .check(["executorch_exir_dialects_edge__ops_aten_sigmoid_default"]) + .check_count({"torch.ops.higher_order.executorch_call_delegate": 1}) + .to_executorch() + .run_method_and_compare_outputs(inputs=inputs) + ) + assert check.has_rejected_node() + + +def test_module_instance_reject(): + module = NestedModule() + inputs = module.inputs + compile_spec = common.get_tosa_compile_spec("TOSA-0.80+MI") + check = DontPartitionModule(instance_name="nested") + partitioner = ArmPartitioner(compile_spec, additional_checks=[check]) + ( + ArmTester( + module, + example_inputs=inputs, + compile_spec=compile_spec, + ) + .export() + .to_edge_transform_and_lower(partitioners=[partitioner]) + .check(["executorch_exir_dialects_edge__ops_aten_sigmoid_default"]) + .check_count({"torch.ops.higher_order.executorch_call_delegate": 1}) + .to_executorch() + .run_method_and_compare_outputs(inputs=inputs) + ) + assert check.has_rejected_node() diff --git a/backends/arm/test/tester/arm_tester.py b/backends/arm/test/tester/arm_tester.py index 11e7d863043..4c74cb656c8 100644 --- a/backends/arm/test/tester/arm_tester.py +++ b/backends/arm/test/tester/arm_tester.py @@ -228,7 +228,7 @@ class ArmTester(Tester): def __init__( self, model: torch.nn.Module, - example_inputs: Tuple[torch.Tensor], + example_inputs: Tuple, compile_spec: List[CompileSpec], ): """ diff --git a/exir/backend/operator_support.py b/exir/backend/operator_support.py new file mode 100644 index 00000000000..5dc0baba756 --- /dev/null +++ b/exir/backend/operator_support.py @@ -0,0 +1,127 @@ +# Copyright 2025 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from torch import fx +from torch.fx.passes.operator_support import OperatorSupportBase + + +def _compare(exact: bool, search_for: str | None, search_in: str) -> bool: + """Check whether the search_for str matches the search_in str. + Match can mean "identical" or "part of" depending on the `exact` flag. + """ + if not search_for: + return False + if exact: + return search_for == search_in + else: + return search_for in search_in + + +class DontSupportBase(OperatorSupportBase): + _rejected_nodes: list[fx.Node] = [] + + def reject_node(self, node: fx.Node): + self._rejected_nodes.append(node) + + def rejected_nodes(self): + return self._rejected_nodes + + def has_rejected_node(self) -> bool: + return self.num_rejected() > 0 + + def num_rejected(self) -> int: + return len(self._rejected_nodes) + + +class DontPartition(DontSupportBase): + """Operator check to skip partitioning ops based on their target. + The target can be an EdgeOverloadOp (exir_ops.edge.aten.*), + OverloadOp (torch.ops.aten.*), or a string ("aten.*"). + + For the string case, set `exact` to False to match only part of the name. + """ + + def __init__(self, *targets, exact: bool = True): + self.targets = targets + self.exact = exact + + def is_node_supported(self, submodules, node: fx.Node) -> bool: + if node.target in self.targets: + self.reject_node(node) + return False + + if "original_aten" not in node.meta: + return True + stringified_node_target = str(node.meta["original_aten"]) + for target in self.targets: + if _compare(self.exact, str(target), stringified_node_target): + self.reject_node(node) + return False + return True + + +class DontPartitionName(DontSupportBase): + """Operator check to skip partitioning ops based on their name, which can be found + by for example node.name or print-outs of a GraphModule. + + Set `exact` to False to match only part of the name. + """ + + def __init__(self, *targets, exact: bool = True): + self.targets = targets + self.exact = exact + + def is_node_supported(self, submodules, node: fx.Node) -> bool: + for target in self.targets: + if _compare(self.exact, target, node.name): + self.reject_node(node) + return False + return True + + +class DontPartitionModule(DontSupportBase): + """Operator check to skip partitioning modules. + You can pass either the module name, i.e. the class name of the module, + or the name of the instance that you want to skip. + If module_name contains a dot, the full module name of checked nodes is used, + if it does not, only part after the last dot is used. + + For example, you could have two files defining MyClass, which have the full module name: + my_file.MyClass + my_other_file.MyClass + If you would call DontPartitionModule with module_name="MyClass", you would skip partitioning both. + With "my_file.MyClass", you would only target the first class. + + Set `exact` to False to match only part of the name. + """ + + def __init__( + self, + *, + module_name: str | None = None, + instance_name: str | None = None, + exact: bool = True, + ): + self.module_name = module_name + self.instance_name = instance_name + self.exact = exact + self.used_dotted = "." in module_name if module_name else True + + def is_node_supported(self, submodules, node: fx.Node) -> bool: + if "nn_module_stack" not in node.meta: + return True + + for module_meta in node.meta["nn_module_stack"].values(): + if _compare(self.exact, self.instance_name, module_meta[0]): + self.reject_node(node) + return False + node_module_name = module_meta[1] + if not self.used_dotted: + node_module_name = node_module_name.split(".")[-1] + if _compare(self.exact, self.module_name, node_module_name): + self.reject_node(node) + return False + + return True