Skip to content

Commit 7b14f26

Browse files
committed
Refactor TOSA support to use chain flow
This seems to match the intention of OperatorSupportBase and increases the flexibility of our support flow. New checks if different kinds are simply added to `tosa_support_factory` Signed-off-by: Erik Lundell <[email protected]> Change-Id: I4f5f1669959f35c5e2770da396a352e1730cb3e0
1 parent 8d96d74 commit 7b14f26

File tree

7 files changed

+42
-47
lines changed

7 files changed

+42
-47
lines changed

backends/arm/arm_partitioner.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
ArmBackend,
1515
) # usort: skip
1616
from executorch.backends.arm.operator_support.tosa_supported_operators import (
17-
TOSASupportedOperators,
17+
tosa_support_factory,
1818
)
1919
from executorch.backends.arm.tosa_specification import TosaSpecification
2020
from executorch.exir.backend.compile_spec_schema import CompileSpec
@@ -72,7 +72,7 @@ def partition(self, exported_program: ExportedProgram) -> PartitionResult:
7272

7373
capability_partitioner = CapabilityBasedPartitioner(
7474
exported_program.graph_module,
75-
TOSASupportedOperators(tosa_spec),
75+
tosa_support_factory(tosa_spec),
7676
allows_single_node_partition=True,
7777
)
7878
partition_list = capability_partitioner.propose_partitions()

backends/arm/operator_support/convolution_support.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ class ConvolutionSupported(SupportedTOSAOperatorCheck):
2424
TosaSpecification.create_from_string("TOSA-0.80+MI"),
2525
]
2626

27-
def is_node_supported(self, node: fx.Node, tosa_spec: TosaSpecification):
27+
def is_node_tosa_supported(self, node: fx.Node, tosa_spec: TosaSpecification):
2828

2929
# Not implemented
3030
transposed = cast(bool, node.args[6])

backends/arm/operator_support/pool_2d_support.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ class AvgPool2dSupported(SupportedTOSAOperatorCheck):
4343
TosaSpecification.create_from_string("TOSA-0.80+MI"),
4444
]
4545

46-
def is_node_supported(self, node: fx.Node, tosa_spec: TosaSpecification):
46+
def is_node_tosa_supported(self, node: fx.Node, tosa_spec: TosaSpecification):
4747
if not (isinstance(tosa_spec, Tosa_0_80) and tosa_spec.is_U55_subset):
4848
return True
4949

@@ -73,7 +73,7 @@ class MaxPool2dSupported(SupportedTOSAOperatorCheck):
7373
TosaSpecification.create_from_string("TOSA-0.80+MI"),
7474
]
7575

76-
def is_node_supported(self, node: fx.Node, tosa_spec: TosaSpecification):
76+
def is_node_tosa_supported(self, node: fx.Node, tosa_spec: TosaSpecification):
7777
if not (isinstance(tosa_spec, Tosa_0_80) and tosa_spec.is_U55_subset):
7878
return True
7979

backends/arm/operator_support/reduce_sum_support.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ class SumSupported(SupportedTOSAOperatorCheck):
2323
TosaSpecification.create_from_string("TOSA-0.80+MI"),
2424
]
2525

26-
def is_node_supported(self, node: fx.Node, tosa_spec: TosaSpecification):
26+
def is_node_tosa_supported(self, node: fx.Node, tosa_spec: TosaSpecification):
2727
if not (isinstance(tosa_spec, Tosa_0_80) and tosa_spec.is_U55_subset):
2828
return True
2929

backends/arm/operator_support/right_shift_support.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright 2024 Arm Limited and/or its affiliates.
1+
# Copyright 2024-2025 Arm Limited and/or its affiliates.
22
#
33
# This source code is licensed under the BSD-style license found in the
44
# LICENSE file in the root directory of this source tree.
@@ -29,7 +29,7 @@ class RightShiftSupported(SupportedTOSAOperatorCheck):
2929
TosaSpecification.create_from_string("TOSA-0.80+MI"),
3030
]
3131

32-
def is_node_supported(self, node: fx.Node, tosa_spec: TosaSpecification):
32+
def is_node_tosa_supported(self, node: fx.Node, tosa_spec: TosaSpecification):
3333

3434
# TODO MLETORCH-525 Remove warning
3535
if isinstance(tosa_spec, Tosa_0_80) and tosa_spec.is_U55_subset:

backends/arm/operator_support/to_copy_support.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,9 @@ def _merge_supported_types(
7070
)
7171
POSSIBLE_TYPE_CONVERSIONS = {torch.int64: torch.int32}
7272

73-
def is_node_supported(self, node: fx.Node, tosa_spec: TosaSpecification) -> bool:
73+
def is_node_tosa_supported(
74+
self, node: fx.Node, tosa_spec: TosaSpecification
75+
) -> bool:
7476
assert node.target in self.targets
7577

7678
if tosa_spec not in self.tosa_specs:

backends/arm/operator_support/tosa_supported_operators.py

Lines changed: 31 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -6,70 +6,77 @@
66
# pyre-unsafe
77

88
import operator
9-
from typing import Type
9+
from typing import final, Type
1010

1111
import torch.fx as fx
1212
from executorch.backends.arm.tosa_specification import TosaSpecification
1313
from executorch.exir.dialects._ops import ops as exir_ops
14-
from torch.fx.passes.operator_support import OperatorSupportBase
14+
from torch.fx.passes.operator_support import any_chain, OperatorSupportBase
1515

1616

17-
class SupportedTOSAOperatorCheck:
17+
class SupportedTOSAOperatorCheck(OperatorSupportBase):
1818
"""
1919
Supported OP for TOSA lowering
2020
"""
2121

22+
def __init__(self, tosa_spec: TosaSpecification):
23+
self.tosa_spec = tosa_spec
24+
2225
# Should be populated by subclass implementation
2326
tosa_specs: list[TosaSpecification] = []
2427
targets: list[str] = []
2528

26-
def is_node_supported(self, node: fx.Node, tosa_spec: TosaSpecification) -> bool:
29+
@final
30+
def is_node_supported(self, submodules, node: fx.Node) -> bool:
31+
if node.target not in self.targets:
32+
return False
33+
return self.is_node_tosa_supported(node, self.tosa_spec)
34+
35+
def is_node_tosa_supported(
36+
self, node: fx.Node, tosa_spec: TosaSpecification
37+
) -> bool:
2738
"""
2839
Checks if the fx.Node node is lowerable using the TOSA specification defined by tosa_spec.
29-
To be implemented by subclasses targeting
3040
"""
31-
raise NotImplementedError("NodeVisitor must be extended.")
41+
raise NotImplementedError("SupportedTOSAOperatorCheck must be extended.")
3242

3343

3444
# container for all SupportedTosaOperatorCheck classes
35-
_tosa_spec_dicts: dict[
36-
TosaSpecification, dict[str, Type[SupportedTOSAOperatorCheck]]
37-
] = {
38-
TosaSpecification.create_from_string("TOSA-0.80+BI"): {},
39-
TosaSpecification.create_from_string("TOSA-0.80+MI"): {},
45+
_tosa_spec_support: dict[TosaSpecification, list[Type[SupportedTOSAOperatorCheck]]] = {
46+
TosaSpecification.create_from_string("TOSA-0.80+BI"): [],
47+
TosaSpecification.create_from_string("TOSA-0.80+MI"): [],
4048
}
4149

4250

43-
def register_tosa_support_check(checker):
51+
def register_tosa_support_check(checker: Type[SupportedTOSAOperatorCheck]):
4452
"""
4553
Decorator to mark a subclass implmentation of SupportedTosaOperatorCheck
4654
to be registered for checking if a torch.fx.Node is lowerable given
4755
a TOSA specification.
4856
"""
4957
for tosa_spec in checker.tosa_specs:
50-
for target in checker.targets:
51-
_tosa_spec_dicts[tosa_spec][target] = checker
58+
_tosa_spec_support[tosa_spec].append(checker)
5259
return checker
5360

5461

5562
def get_registered_tosa_support_checks(
5663
tosa_spec: TosaSpecification,
57-
) -> dict[str, SupportedTOSAOperatorCheck]:
64+
) -> list[Type[SupportedTOSAOperatorCheck]]:
5865

59-
if tosa_spec not in _tosa_spec_dicts:
66+
if tosa_spec not in _tosa_spec_support:
6067
raise RuntimeError
6168

62-
tosa_support_checks = {}
63-
for target, tosa_check in _tosa_spec_dicts[tosa_spec].items():
64-
tosa_support_checks[target] = tosa_check()
69+
return _tosa_spec_support[tosa_spec]
6570

66-
return tosa_support_checks
6771

72+
def tosa_support_factory(tosa_spec: TosaSpecification) -> OperatorSupportBase:
73+
return any_chain(
74+
BaseTOSASupportList(),
75+
*(check(tosa_spec) for check in get_registered_tosa_support_checks(tosa_spec)),
76+
)
6877

69-
class TOSASupportedOperators(OperatorSupportBase):
70-
def __init__(self, tosa_spec: TosaSpecification):
71-
super().__init__()
72-
self.tosa_spec = tosa_spec
78+
79+
class BaseTOSASupportList(OperatorSupportBase):
7380

7481
def is_node_supported(self, submodules, node: fx.Node) -> bool:
7582
supported = node.op == "call_function" and node.target in [
@@ -123,18 +130,4 @@ def is_node_supported(self, submodules, node: fx.Node) -> bool:
123130
exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default,
124131
]
125132

126-
if not supported:
127-
supported = self.is_node_supported_custom(node)
128-
129-
# Override partitioning based on pre partition passes
130-
if "arm_override_partition" in node.meta:
131-
supported = supported & node.meta["arm_override_partition"]
132-
node.meta.pop("arm_override_partition")
133-
134133
return supported
135-
136-
def is_node_supported_custom(self, node: fx.Node) -> bool:
137-
tosa_checks = get_registered_tosa_support_checks(self.tosa_spec)
138-
if node.target in tosa_checks.keys():
139-
return tosa_checks[node.target].is_node_supported(node, self.tosa_spec) # type: ignore[index]
140-
return False

0 commit comments

Comments
 (0)