|
8 | 8 | from executorch.backends.nxp.backend.custom_delegation_options import ( |
9 | 9 | CustomDelegationOptions, |
10 | 10 | ) |
| 11 | +from executorch.backends.nxp.backend.edge_helper import previous_non_qdq_node |
11 | 12 | from executorch.backends.nxp.backend.ir.converter.conversion import translator |
12 | 13 | from executorch.backends.nxp.backend.ir.converter.conversion.translator import ( |
13 | 14 | apply_permutation_to, |
|
24 | 25 | from executorch.backends.nxp.backend.neutron_target_spec import NeutronTargetSpec |
25 | 26 | from executorch.backends.nxp.backend.node_format import NXP_NODE_FORMAT |
26 | 27 | from torch.fx import Node |
| 28 | +from torch.fx.passes.infra.partitioner import Partition |
27 | 29 | from torch.nn import Parameter |
28 | 30 |
|
29 | 31 |
|
@@ -86,30 +88,6 @@ def _is_supported_on_target( |
86 | 88 |
|
87 | 89 | dim = CatConverter._get_normalized_dim(node) |
88 | 90 |
|
89 | | - # There is a bug in the NeutronConverter, where if none of the dimensions before the one referenced by |
90 | | - # `dim` are `!= 1`, the `Concat` is not delegated. |
91 | | - # This only happens when the inputs to the `Concat` are model inputs, and not outputs of other |
92 | | - # operators. However, such a requirement cannot be currently enforced during partitioning, so we must |
93 | | - # take the conservative approach and decline delegation. |
94 | | - input_shapes = [list(n.meta["val"].shape) for n in node.args[0]] |
95 | | - if node.meta[NXP_NODE_FORMAT].is_channels_first(): |
96 | | - # Transform the shapes to channels last. |
97 | | - to_nhwc_perm = create_channels_first_to_channels_last_permutation( |
98 | | - len(node.meta["val"].shape), True |
99 | | - ) |
100 | | - input_shapes = [ |
101 | | - apply_permutation_to(shape, to_nhwc_perm) |
102 | | - for shape in input_shapes |
103 | | - ] |
104 | | - |
105 | | - # Transform the `dim` to refer to a channels last dimension. |
106 | | - dim = to_nhwc_perm.index(dim) |
107 | | - |
108 | | - for input_shape in input_shapes: |
109 | | - if not any(d != 1 for d in input_shape[:dim]): |
110 | | - # Do not delegate if there are no "non-1" dimensions in the shape before the `dim` dimension. |
111 | | - return False |
112 | | - |
113 | 91 | # Neutron requires the channels to be a multiple of `num_macs`. The channels could either be the second or the |
114 | 92 | # last dimension, depending on the formats of the node. |
115 | 93 | if node.meta[NXP_NODE_FORMAT].is_channels_first(): |
@@ -172,6 +150,46 @@ def _is_supported_in_IR( |
172 | 150 |
|
173 | 151 | return True |
174 | 152 |
|
| 153 | + @classmethod |
| 154 | + def supports_partitioning_result( |
| 155 | + cls, |
| 156 | + node: Node, |
| 157 | + partition_list: list[Partition], |
| 158 | + custom_delegation_options: CustomDelegationOptions, |
| 159 | + ): |
| 160 | + # There is a bug in the NeutronConverter, where if none of the input dimensions before the one referenced by |
| 161 | + # `dim` are `!= 1`, the `Concat` is not delegated. |
| 162 | + # This only happens when the inputs to the `Concat` are model inputs, and not outputs of other |
| 163 | + # operators. |
| 164 | + cat_partition = [p for p in partition_list if node in p.nodes][0] |
| 165 | + cat_inputs = map(previous_non_qdq_node, node.args[0]) |
| 166 | + |
| 167 | + if not all( |
| 168 | + input_.op == "call_function" and input_ in cat_partition.nodes |
| 169 | + for input_ in cat_inputs |
| 170 | + ): |
| 171 | + # Some inputs of the `cat` are NOT in the same partition as `cat`. |
| 172 | + dim = CatConverter._get_normalized_dim(node) |
| 173 | + input_shapes = [list(n.meta["val"].shape) for n in node.args[0]] |
| 174 | + if node.meta[NXP_NODE_FORMAT].is_channels_first(): |
| 175 | + # Transform the shapes to channels last. |
| 176 | + to_nhwc_perm = create_channels_first_to_channels_last_permutation( |
| 177 | + len(node.meta["val"].shape), True |
| 178 | + ) |
| 179 | + input_shapes = [ |
| 180 | + apply_permutation_to(shape, to_nhwc_perm) for shape in input_shapes |
| 181 | + ] |
| 182 | + |
| 183 | + # Transform the `dim` to refer to a channels last dimension. |
| 184 | + dim = to_nhwc_perm.index(dim) |
| 185 | + |
| 186 | + for input_shape in input_shapes: |
| 187 | + if not any(d != 1 for d in input_shape[:dim]): |
| 188 | + # Do not delegate if there are no "non-1" dimensions in the shape before the `dim` dimension. |
| 189 | + return False |
| 190 | + |
| 191 | + return True |
| 192 | + |
175 | 193 | def convert(self, node: Node): |
176 | 194 | """Convert the 'aten.cat' operator to TFLite 'Concatenation'.""" |
177 | 195 | self.assert_convertible(node) |
|
0 commit comments