|
6 | 6 | # pyre-unsafe |
7 | 7 |
|
8 | 8 | import operator |
9 | | -from typing import Type |
| 9 | +from typing import final, Type |
10 | 10 |
|
11 | 11 | import torch.fx as fx |
12 | 12 | from executorch.backends.arm.tosa_specification import TosaSpecification |
13 | 13 | from executorch.exir.dialects._ops import ops as exir_ops |
14 | | -from torch.fx.passes.operator_support import OperatorSupportBase |
| 14 | +from torch.fx.passes.operator_support import any_chain, OperatorSupportBase |
15 | 15 |
|
16 | 16 |
|
17 | | -class SupportedTOSAOperatorCheck: |
| 17 | +class SupportedTOSAOperatorCheck(OperatorSupportBase): |
18 | 18 | """ |
19 | 19 | Supported OP for TOSA lowering |
20 | 20 | """ |
21 | 21 |
|
| 22 | + def __init__(self, tosa_spec: TosaSpecification): |
| 23 | + self.tosa_spec = tosa_spec |
| 24 | + |
22 | 25 | # Should be populated by subclass implementation |
23 | 26 | tosa_specs: list[TosaSpecification] = [] |
24 | 27 | targets: list[str] = [] |
25 | 28 |
|
26 | | - def is_node_supported(self, node: fx.Node, tosa_spec: TosaSpecification) -> bool: |
| 29 | + @final |
| 30 | + def is_node_supported(self, submodules, node: fx.Node) -> bool: |
| 31 | + if node.target not in self.targets: |
| 32 | + return False |
| 33 | + return self.is_node_tosa_supported(node, self.tosa_spec) |
| 34 | + |
| 35 | + def is_node_tosa_supported( |
| 36 | + self, node: fx.Node, tosa_spec: TosaSpecification |
| 37 | + ) -> bool: |
27 | 38 | """ |
28 | 39 | Checks if the fx.Node node is lowerable using the TOSA specification defined by tosa_spec. |
29 | | - To be implemented by subclasses targeting |
30 | 40 | """ |
31 | | - raise NotImplementedError("NodeVisitor must be extended.") |
| 41 | + raise NotImplementedError("SupportedTOSAOperatorCheck must be extended.") |
32 | 42 |
|
33 | 43 |
|
34 | 44 | # container for all SupportedTosaOperatorCheck classes |
35 | | -_tosa_spec_dicts: dict[ |
36 | | - TosaSpecification, dict[str, Type[SupportedTOSAOperatorCheck]] |
37 | | -] = { |
38 | | - TosaSpecification.create_from_string("TOSA-0.80+BI"): {}, |
39 | | - TosaSpecification.create_from_string("TOSA-0.80+MI"): {}, |
| 45 | +_tosa_spec_support: dict[TosaSpecification, list[Type[SupportedTOSAOperatorCheck]]] = { |
| 46 | + TosaSpecification.create_from_string("TOSA-0.80+BI"): [], |
| 47 | + TosaSpecification.create_from_string("TOSA-0.80+MI"): [], |
40 | 48 | } |
41 | 49 |
|
42 | 50 |
|
43 | | -def register_tosa_support_check(checker): |
| 51 | +def register_tosa_support_check(checker: Type[SupportedTOSAOperatorCheck]): |
44 | 52 | """ |
45 | 53 | Decorator to mark a subclass implmentation of SupportedTosaOperatorCheck |
46 | 54 | to be registered for checking if a torch.fx.Node is lowerable given |
47 | 55 | a TOSA specification. |
48 | 56 | """ |
49 | 57 | for tosa_spec in checker.tosa_specs: |
50 | | - for target in checker.targets: |
51 | | - _tosa_spec_dicts[tosa_spec][target] = checker |
| 58 | + _tosa_spec_support[tosa_spec].append(checker) |
52 | 59 | return checker |
53 | 60 |
|
54 | 61 |
|
55 | 62 | def get_registered_tosa_support_checks( |
56 | 63 | tosa_spec: TosaSpecification, |
57 | | -) -> dict[str, SupportedTOSAOperatorCheck]: |
| 64 | +) -> list[Type[SupportedTOSAOperatorCheck]]: |
58 | 65 |
|
59 | | - if tosa_spec not in _tosa_spec_dicts: |
| 66 | + if tosa_spec not in _tosa_spec_support: |
60 | 67 | raise RuntimeError |
61 | 68 |
|
62 | | - tosa_support_checks = {} |
63 | | - for target, tosa_check in _tosa_spec_dicts[tosa_spec].items(): |
64 | | - tosa_support_checks[target] = tosa_check() |
| 69 | + return _tosa_spec_support[tosa_spec] |
65 | 70 |
|
66 | | - return tosa_support_checks |
67 | 71 |
|
| 72 | +def tosa_support_factory(tosa_spec: TosaSpecification) -> OperatorSupportBase: |
| 73 | + return any_chain( |
| 74 | + BaseTOSASupportList(), |
| 75 | + *(check(tosa_spec) for check in get_registered_tosa_support_checks(tosa_spec)), |
| 76 | + ) |
68 | 77 |
|
69 | | -class TOSASupportedOperators(OperatorSupportBase): |
70 | | - def __init__(self, tosa_spec: TosaSpecification): |
71 | | - super().__init__() |
72 | | - self.tosa_spec = tosa_spec |
| 78 | + |
| 79 | +class BaseTOSASupportList(OperatorSupportBase): |
73 | 80 |
|
74 | 81 | def is_node_supported(self, submodules, node: fx.Node) -> bool: |
75 | 82 | supported = node.op == "call_function" and node.target in [ |
@@ -123,18 +130,4 @@ def is_node_supported(self, submodules, node: fx.Node) -> bool: |
123 | 130 | exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default, |
124 | 131 | ] |
125 | 132 |
|
126 | | - if not supported: |
127 | | - supported = self.is_node_supported_custom(node) |
128 | | - |
129 | | - # Override partitioning based on pre partition passes |
130 | | - if "arm_override_partition" in node.meta: |
131 | | - supported = supported & node.meta["arm_override_partition"] |
132 | | - node.meta.pop("arm_override_partition") |
133 | | - |
134 | 133 | return supported |
135 | | - |
136 | | - def is_node_supported_custom(self, node: fx.Node) -> bool: |
137 | | - tosa_checks = get_registered_tosa_support_checks(self.tosa_spec) |
138 | | - if node.target in tosa_checks.keys(): |
139 | | - return tosa_checks[node.target].is_node_supported(node, self.tosa_spec) # type: ignore[index] |
140 | | - return False |
0 commit comments