Skip to content

Commit 2584dc4

Browse files
committed
Add additional checks for operator support.
This can be used to avoid partitioning parts of a model when debugging. Though any OperatorSupportBase can be used, we add three OperatorSupport as utilities: DontPartition: Don't partition based on node target DontPartitionName: Don't partition based on node name DontPartitionModule: Don't partition based on which module the op comes from. All these checks can match parts of the target name, and save a list of the nodes they reject for debugging. Signed-off-by: Erik Lundell <[email protected]> Change-Id: I0b2537370da4aadcffbb87c52cb98e82d78cf27f
1 parent 7b14f26 commit 2584dc4

File tree

5 files changed

+369
-10
lines changed

5 files changed

+369
-10
lines changed

backends/arm/arm_partitioner.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77

88
import logging
99
import os
10-
from typing import Callable, final, List, Optional, Tuple
10+
from typing import Callable, final, List, Optional, Sequence, Tuple
1111

1212
import torch
1313
from executorch.backends.arm.arm_backend import ( # type: ignore[attr-defined]
@@ -27,6 +27,8 @@
2727
from executorch.exir.dialects._ops import ops as exir_ops
2828
from torch.export.exported_program import ExportedProgram
2929
from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner
30+
from torch.fx.passes.operator_support import OperatorSupportBase
31+
3032

3133
logger = logging.getLogger(__name__)
3234
logger.setLevel(logging.WARNING)
@@ -54,8 +56,13 @@ def is_dequant_node(node: torch.fx.node.Node) -> bool:
5456

5557
@final
5658
class ArmPartitioner(Partitioner):
57-
def __init__(self, compile_spec: List[CompileSpec]) -> None:
59+
def __init__(
60+
self,
61+
compile_spec: List[CompileSpec],
62+
additional_checks: Optional[Sequence[OperatorSupportBase]] = None,
63+
) -> None:
5864
self.delegation_spec = DelegationSpec(ArmBackend.__name__, compile_spec)
65+
self.additional_checks = additional_checks
5966

6067
def partition(self, exported_program: ExportedProgram) -> PartitionResult:
6168
# Run the CapabilityBasedPartitioner to return the largest possible
@@ -72,7 +79,7 @@ def partition(self, exported_program: ExportedProgram) -> PartitionResult:
7279

7380
capability_partitioner = CapabilityBasedPartitioner(
7481
exported_program.graph_module,
75-
tosa_support_factory(tosa_spec),
82+
tosa_support_factory(tosa_spec, self.additional_checks),
7683
allows_single_node_partition=True,
7784
)
7885
partition_list = capability_partitioner.propose_partitions()

backends/arm/operator_support/tosa_supported_operators.py

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,12 @@
66
# pyre-unsafe
77

88
import operator
9-
from typing import final, Type
9+
from typing import final, Optional, Sequence, Type
1010

1111
import torch.fx as fx
1212
from executorch.backends.arm.tosa_specification import TosaSpecification
1313
from executorch.exir.dialects._ops import ops as exir_ops
14-
from torch.fx.passes.operator_support import any_chain, OperatorSupportBase
14+
from torch.fx.passes.operator_support import any_chain, chain, OperatorSupportBase
1515

1616

1717
class SupportedTOSAOperatorCheck(OperatorSupportBase):
@@ -69,10 +69,19 @@ def get_registered_tosa_support_checks(
6969
return _tosa_spec_support[tosa_spec]
7070

7171

72-
def tosa_support_factory(tosa_spec: TosaSpecification) -> OperatorSupportBase:
73-
return any_chain(
74-
BaseTOSASupportList(),
75-
*(check(tosa_spec) for check in get_registered_tosa_support_checks(tosa_spec)),
72+
def tosa_support_factory(
73+
tosa_spec: TosaSpecification,
74+
additional_checks: Optional[Sequence[OperatorSupportBase]] = None,
75+
) -> OperatorSupportBase:
76+
return chain(
77+
any_chain(
78+
BaseTOSASupportList(),
79+
*(
80+
check(tosa_spec)
81+
for check in get_registered_tosa_support_checks(tosa_spec)
82+
),
83+
),
84+
*additional_checks if additional_checks else [],
7685
)
7786

7887

Lines changed: 216 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,216 @@
1+
# Copyright 2025 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
import torch
7+
from executorch.backends.arm.arm_partitioner import ArmPartitioner
8+
from executorch.backends.arm.test import common
9+
from executorch.backends.arm.test.tester.arm_tester import ArmTester
10+
from executorch.exir.backend.operator_support import (
11+
DontPartition,
12+
DontPartitionModule,
13+
DontPartitionName,
14+
)
15+
from executorch.exir.dialects._ops import ops as exir_ops
16+
17+
18+
class CustomPartitioning(torch.nn.Module):
19+
inputs = (torch.randn(10, 4, 5), torch.randn(10, 4, 5))
20+
21+
def forward(self, x: torch.Tensor, y: torch.Tensor):
22+
z = x + y
23+
s = torch.sigmoid(z)
24+
return s * z
25+
26+
27+
class NestedModule(torch.nn.Module):
28+
inputs = (torch.randn(10, 4, 5), torch.randn(10, 4, 5))
29+
30+
def __init__(self):
31+
super().__init__()
32+
self.nested = CustomPartitioning()
33+
34+
def forward(self, x: torch.Tensor, y: torch.Tensor):
35+
a = x.sigmoid()
36+
b = a + y
37+
return self.nested(a, b)
38+
39+
40+
def test_single_reject():
41+
module = CustomPartitioning()
42+
inputs = module.inputs
43+
compile_spec = common.get_tosa_compile_spec("TOSA-0.80+MI")
44+
check = DontPartition(exir_ops.edge.aten.sigmoid.default)
45+
partitioner = ArmPartitioner(compile_spec, additional_checks=[check])
46+
(
47+
ArmTester(
48+
module,
49+
example_inputs=inputs,
50+
compile_spec=compile_spec,
51+
)
52+
.export()
53+
.to_edge_transform_and_lower(partitioners=[partitioner])
54+
.check(["executorch_exir_dialects_edge__ops_aten_sigmoid_default"])
55+
.check_count({"torch.ops.higher_order.executorch_call_delegate": 2})
56+
.to_executorch()
57+
.run_method_and_compare_outputs(inputs=inputs)
58+
)
59+
assert check.has_rejected_node()
60+
61+
62+
def test_multiple_reject():
63+
module = CustomPartitioning()
64+
inputs = module.inputs
65+
compile_spec = common.get_tosa_compile_spec("TOSA-0.80+MI")
66+
check = DontPartition(
67+
exir_ops.edge.aten.sigmoid.default, exir_ops.edge.aten.mul.Tensor
68+
)
69+
partitioner = ArmPartitioner(compile_spec, additional_checks=[check])
70+
(
71+
ArmTester(
72+
module,
73+
example_inputs=inputs,
74+
compile_spec=compile_spec,
75+
)
76+
.export()
77+
.to_edge_transform_and_lower(partitioners=[partitioner])
78+
.check(["executorch_exir_dialects_edge__ops_aten_sigmoid_default"])
79+
.check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
80+
.to_executorch()
81+
.run_method_and_compare_outputs(inputs=inputs)
82+
)
83+
assert check.has_rejected_node()
84+
85+
86+
def test_torch_op_reject():
87+
module = CustomPartitioning()
88+
inputs = module.inputs
89+
compile_spec = common.get_tosa_compile_spec("TOSA-0.80+MI")
90+
check = DontPartition(torch.ops.aten.sigmoid.default)
91+
partitioner = ArmPartitioner(compile_spec, additional_checks=[check])
92+
(
93+
ArmTester(
94+
module,
95+
example_inputs=inputs,
96+
compile_spec=compile_spec,
97+
)
98+
.export()
99+
.to_edge_transform_and_lower(partitioners=[partitioner])
100+
.check(["executorch_exir_dialects_edge__ops_aten_sigmoid_default"])
101+
.check_count({"torch.ops.higher_order.executorch_call_delegate": 2})
102+
.to_executorch()
103+
.run_method_and_compare_outputs(inputs=inputs)
104+
)
105+
assert check.has_rejected_node()
106+
107+
108+
def test_string_op_reject():
109+
module = CustomPartitioning()
110+
inputs = module.inputs
111+
compile_spec = common.get_tosa_compile_spec("TOSA-0.80+MI")
112+
check = DontPartition("aten.sigmoid.default")
113+
partitioner = ArmPartitioner(compile_spec, additional_checks=[check])
114+
(
115+
ArmTester(
116+
module,
117+
example_inputs=inputs,
118+
compile_spec=compile_spec,
119+
)
120+
.export()
121+
.to_edge_transform_and_lower(partitioners=[partitioner])
122+
.check(["executorch_exir_dialects_edge__ops_aten_sigmoid_default"])
123+
.check_count({"torch.ops.higher_order.executorch_call_delegate": 2})
124+
.to_executorch()
125+
.run_method_and_compare_outputs(inputs=inputs)
126+
)
127+
128+
assert check.has_rejected_node()
129+
130+
131+
def test_name_reject():
132+
module = CustomPartitioning()
133+
inputs = module.inputs
134+
compile_spec = common.get_tosa_compile_spec("TOSA-0.80+MI")
135+
check = DontPartitionName("mul", "sigmoid", exact=False)
136+
partitioner = ArmPartitioner(compile_spec, additional_checks=[check])
137+
(
138+
ArmTester(
139+
module,
140+
example_inputs=inputs,
141+
compile_spec=compile_spec,
142+
)
143+
.export()
144+
.to_edge_transform_and_lower(partitioners=[partitioner])
145+
.check(["executorch_exir_dialects_edge__ops_aten_sigmoid_default"])
146+
.check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
147+
.to_executorch()
148+
.run_method_and_compare_outputs(inputs=inputs)
149+
)
150+
assert check.has_rejected_node()
151+
152+
153+
def test_module_reject():
154+
module = NestedModule()
155+
inputs = module.inputs
156+
compile_spec = common.get_tosa_compile_spec("TOSA-0.80+MI")
157+
check = DontPartitionModule(module_name="CustomPartitioning")
158+
partitioner = ArmPartitioner(compile_spec, additional_checks=[check])
159+
(
160+
ArmTester(
161+
module,
162+
example_inputs=inputs,
163+
compile_spec=compile_spec,
164+
)
165+
.export()
166+
.to_edge_transform_and_lower(partitioners=[partitioner])
167+
.check(["executorch_exir_dialects_edge__ops_aten_sigmoid_default"])
168+
.check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
169+
.to_executorch()
170+
.run_method_and_compare_outputs(inputs=inputs)
171+
)
172+
assert check.has_rejected_node()
173+
174+
175+
def test_inexact_module_reject():
176+
module = NestedModule()
177+
inputs = module.inputs
178+
compile_spec = common.get_tosa_compile_spec("TOSA-0.80+MI")
179+
check = DontPartitionModule(module_name="Custom", exact=False)
180+
partitioner = ArmPartitioner(compile_spec, additional_checks=[check])
181+
(
182+
ArmTester(
183+
module,
184+
example_inputs=inputs,
185+
compile_spec=compile_spec,
186+
)
187+
.export()
188+
.to_edge_transform_and_lower(partitioners=[partitioner])
189+
.check(["executorch_exir_dialects_edge__ops_aten_sigmoid_default"])
190+
.check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
191+
.to_executorch()
192+
.run_method_and_compare_outputs(inputs=inputs)
193+
)
194+
assert check.has_rejected_node()
195+
196+
197+
def test_module_instance_reject():
198+
module = NestedModule()
199+
inputs = module.inputs
200+
compile_spec = common.get_tosa_compile_spec("TOSA-0.80+MI")
201+
check = DontPartitionModule(instance_name="nested")
202+
partitioner = ArmPartitioner(compile_spec, additional_checks=[check])
203+
(
204+
ArmTester(
205+
module,
206+
example_inputs=inputs,
207+
compile_spec=compile_spec,
208+
)
209+
.export()
210+
.to_edge_transform_and_lower(partitioners=[partitioner])
211+
.check(["executorch_exir_dialects_edge__ops_aten_sigmoid_default"])
212+
.check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
213+
.to_executorch()
214+
.run_method_and_compare_outputs(inputs=inputs)
215+
)
216+
assert check.has_rejected_node()

backends/arm/test/tester/arm_tester.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -228,7 +228,7 @@ class ArmTester(Tester):
228228
def __init__(
229229
self,
230230
model: torch.nn.Module,
231-
example_inputs: Tuple[torch.Tensor],
231+
example_inputs: Tuple,
232232
compile_spec: List[CompileSpec],
233233
):
234234
"""

0 commit comments

Comments
 (0)