Skip to content

Commit 4547cca

Browse files
committed
NXP backend: Add tensor format support for permute_copy.
1 parent 83cd7ca commit 4547cca

File tree

4 files changed

+450
-33
lines changed

4 files changed

+450
-33
lines changed

backends/nxp/backend/ir/converter/node_converter.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -185,6 +185,14 @@ def builder(self) -> AtenModelBuilderDirector:
185185
"""
186186
return self.context.tflite_builder
187187

188+
@property
def neutron_target_spec(self) -> NeutronTargetSpec:
    """
    Convenience accessor for the target specification held by the model builder.

    :return: The `neutron_target_spec` attribute of `self.builder`.
    """
    # Go through the `builder` property rather than the context directly, so there is a single access path.
    target_spec = self.builder.neutron_target_spec
    return target_spec
195+
188196
def _create_tflite_op_with_io_tensors(self, node: Node) -> tflite_model.Operator:
189197
"""
190198
Create TFLite op wrapper with input/output tensors added into 'tmp_inputs' and 'tmp_outputs'.

backends/nxp/backend/ir/converter/node_converters/ops_converters/permute_copy_converter.py

Lines changed: 282 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -6,19 +6,27 @@
66
import numpy as np
77
import torch
88

9+
from executorch.backends.nxp.backend.edge_helper import (
10+
node_is_effectively_static_tensor,
11+
)
912
from executorch.backends.nxp.backend.ir.converter import quantization_utils
13+
from executorch.backends.nxp.backend.ir.converter.conversion import translator
1014
from executorch.backends.nxp.backend.ir.converter.conversion.common import OpsList
1115
from executorch.backends.nxp.backend.ir.converter.node_converter import (
1216
CustomDelegationOptions,
1317
NeutronTargetSpec,
1418
NodeConverter,
1519
)
20+
from executorch.backends.nxp.backend.ir.tensor_formatting import TensorFormat
21+
from executorch.backends.nxp.backend.ir.tflite_generator import tflite_model
1622
from executorch.backends.nxp.backend.ir.tflite_generator.builtin_options import (
1723
transpose_options,
1824
)
1925
from executorch.backends.nxp.backend.neutron_operator_support import (
26+
is_tensor_invariant_permutation,
2027
transposition_is_supported_on_neutron,
2128
)
29+
from executorch.backends.nxp.backend.node_format import NXP_NODE_FORMAT
2230
from torch.fx import Node
2331
from torch.nn import Parameter
2432

@@ -35,14 +43,111 @@ def _is_supported_on_target(
3543
parameters_mapping: dict[str, Parameter],
3644
custom_delegation_options: CustomDelegationOptions,
3745
) -> bool:
46+
if node_is_effectively_static_tensor(node.args[0], parameters_mapping):
47+
return (
48+
True # The operator computes on static data. It will be removed later.
49+
)
50+
3851
input_shape = _get_shape(node.args[0])
39-
permutation = list(node.args[1])
52+
perm = list(node.args[1])
53+
output_shape = _get_shape(node)
54+
55+
# Since ExecuTorch and NeutronIR use different tensor formats, we must consider the different possible cases
56+
# which may occur. The main permutation is always done on channels_first/formatless data, and the output is
57+
# channels_first/formatless as well. If this is not the case, a `Transpose` is inserted before and/or
58+
# after the main `Transpose`, to make the input/output channels_first. These additional `Transpose`
59+
# ops must be supported by Neutron as well. Alternatively, consecutive `Transpose` ops can be fused
60+
# together. It is possible for a pair of unsupported permutations to result in a supported one.
61+
# Therefore, the merged permutations must also be considered.
62+
to_nchw_perm = translator.create_channels_last_to_channels_first_permutation(
63+
len(input_shape), True
64+
)
65+
to_nhwc_perm = translator.create_channels_first_to_channels_last_permutation(
66+
len(input_shape), True
67+
)
68+
channels_last_input_shape = translator.apply_permutation_to(
69+
input_shape, to_nhwc_perm
70+
)
71+
72+
if is_tensor_invariant_permutation(
73+
input_shape, perm
74+
) and is_tensor_invariant_permutation(channels_last_input_shape, perm):
75+
# The `permute_copy` can always be represented as a Reshape.
76+
return True
4077

41-
# TODO Handle tensor formats properly.
78+
main_perm_supported = transposition_is_supported_on_neutron(
79+
input_shape, perm, neutron_target_spec
80+
)
81+
# "To NCHW" permutation, in case the input is channels last.
82+
separate_pre_transpose_supported = transposition_is_supported_on_neutron(
83+
channels_last_input_shape, to_nchw_perm, neutron_target_spec
84+
)
85+
# The main permutation and the previous one merged.
86+
merged_pre_transpose_supported = transposition_is_supported_on_neutron(
87+
channels_last_input_shape,
88+
translator.combine_permutations(to_nchw_perm, perm),
89+
neutron_target_spec,
90+
)
91+
# "To NHWC" permutation after the main `Transpose`.
92+
separate_post_transpose_supported = transposition_is_supported_on_neutron(
93+
output_shape, to_nhwc_perm, neutron_target_spec
94+
)
95+
# The main permutation and the following "to NHWC" one merged.
96+
merged_post_transpose_supported = transposition_is_supported_on_neutron(
97+
input_shape,
98+
translator.combine_permutations(perm, to_nhwc_perm),
99+
neutron_target_spec,
100+
)
101+
# "To NCHW", main permutation, and "to NHWC" all merged.
102+
everything_merged_supported = transposition_is_supported_on_neutron(
103+
input_shape,
104+
translator.combine_permutations(
105+
translator.combine_permutations(to_nchw_perm, perm), to_nhwc_perm
106+
),
107+
neutron_target_spec,
108+
)
42109

43-
return transposition_is_supported_on_neutron(
44-
input_shape, permutation, neutron_target_spec
110+
input_format, output_format = (
111+
node.args[0].meta[NXP_NODE_FORMAT],
112+
node.meta[NXP_NODE_FORMAT],
45113
)
114+
if input_format.is_channels_first() and (not output_format.is_channels_first()):
115+
# Just the input must be permuted.
116+
return (
117+
separate_pre_transpose_supported and main_perm_supported
118+
) or merged_pre_transpose_supported
119+
120+
elif (
121+
not input_format.is_channels_first()
122+
) and output_format.is_channels_first():
123+
# Just the output must be permuted.
124+
return (
125+
separate_post_transpose_supported and main_perm_supported
126+
) or merged_post_transpose_supported
127+
128+
elif input_format.is_channels_first() and output_format.is_channels_first():
129+
# Both input and output must be permuted.
130+
return (
131+
# Separate IO transpositions.
132+
(
133+
separate_pre_transpose_supported
134+
and main_perm_supported
135+
and separate_post_transpose_supported
136+
)
137+
# Separate input, merged output.
138+
or (
139+
separate_pre_transpose_supported and merged_post_transpose_supported
140+
)
141+
# Merged input, separate output.
142+
or (
143+
merged_pre_transpose_supported and separate_post_transpose_supported
144+
)
145+
# Merged input and output.
146+
or everything_merged_supported
147+
)
148+
else:
149+
# Simplest case. No format changes required.
150+
return main_perm_supported
46151

47152
@staticmethod
48153
def _is_supported_in_IR(
@@ -55,6 +160,177 @@ def _is_supported_in_IR(
55160

56161
return True
57162

163+
def handle_tensor_formats(self, t_op: tflite_model.Operator, node: Node) -> OpsList:
    """Reconcile the ExecuTorch and NeutronIR tensor formats of a `permute_copy` node.

    Depending on whether the node's input and/or output is marked channels first, the main permutation is either
    merged with a "to NCHW" / "to NHWC" permutation, or extra `Transpose` operators are placed before/after `t_op`.
    Merged variants are tried before separate ones, so that as few `Transpose` operators as possible are emitted.

    :param t_op: The TFLite operator created for `node`. Its first input is kept, and the tensor holding the final
                 permutation is set as its second input.
    :param node: The `permute_copy` node. `node.args[0]` is the input node and `node.args[1]` the permutation.
    :return: `OpsList` with `t_op` as the middle operator, plus any `Transpose` operators added before/after it.
    :raises RuntimeError: If no supported combination of permutations exists and the input is not effectively
                          static.
    """
    spec = self.neutron_target_spec

    def can_transpose(shape, permutation):
        # Thin wrapper, so that every check below is made against the same target spec.
        return transposition_is_supported_on_neutron(shape, permutation, spec)

    def reject_unless_static():
        # Reaching this point means no supported variant was found. It is tolerated only for effectively static
        # inputs, in which case the permutation is left as it is.
        if not node_is_effectively_static_tensor(
            node.args[0], self.context.parameters_mapping
        ):
            # The `permute_copy` cannot be represented in Neutron. This should never happen.
            raise RuntimeError(
                "A `permute_copy` node was incorrectly selected for delegation. Please report this."
            )

    def prepend_to_nchw():
        # Place a `Transpose` using `to_nchw_perm` in front of input 0 of the main operator.
        ops.add_pre(
            self.builder.create_transpose_operator_before(t_op, 0, to_nchw_perm)
        )

    def append_to_nhwc():
        # Place a `Transpose` using `to_nhwc_perm` behind output 0 of the main operator.
        ops.add_post(
            self.builder.create_transpose_operator_after(t_op, 0, to_nhwc_perm)
        )

    input_shape = node.args[0].meta["val"].shape
    output_shape = node.meta["val"].shape
    perm = list(node.args[1])
    rank = len(input_shape)

    to_nchw_perm = translator.create_channels_last_to_channels_first_permutation(
        rank, True
    )
    to_nhwc_perm = translator.create_channels_first_to_channels_last_permutation(
        rank, True
    )
    channels_last_input_shape = translator.apply_permutation_to(
        input_shape, to_nhwc_perm
    )

    # The main permutation on its own.
    main_perm_supported = can_transpose(input_shape, perm)

    # "To NCHW" permutation, in case the input is channels last.
    separate_pre_transpose_supported = can_transpose(
        channels_last_input_shape, to_nchw_perm
    )

    # "To NCHW" and the main permutation merged into one.
    merged_pre_transpose_permutation = translator.combine_permutations(
        to_nchw_perm, perm
    )
    merged_pre_transpose_supported = can_transpose(
        channels_last_input_shape, merged_pre_transpose_permutation
    )

    # "To NHWC" permutation after the main `Transpose`.
    separate_post_transpose_supported = can_transpose(output_shape, to_nhwc_perm)

    # The main permutation and the following "to NHWC" one merged.
    merged_post_transpose_permutation = translator.combine_permutations(
        perm, to_nhwc_perm
    )
    merged_post_transpose_supported = can_transpose(
        input_shape, merged_post_transpose_permutation
    )

    # "To NCHW", main permutation, and "to NHWC" all merged.
    # NOTE(review): this check is given `input_shape`, whereas the merged "to NCHW" + main check above is given
    #  `channels_last_input_shape`, although both permutations start with `to_nchw_perm`. Confirm which shape is
    #  intended (the same applies to the matching check in `_is_supported_on_target`).
    everything_merged_permutation = translator.combine_permutations(
        translator.combine_permutations(to_nchw_perm, perm), to_nhwc_perm
    )
    everything_merged_supported = can_transpose(
        input_shape, everything_merged_permutation
    )

    ops = OpsList(middle_op=t_op)

    input_is_channels_first = node.args[0].meta[NXP_NODE_FORMAT].is_channels_first()
    output_is_channels_first = node.meta[NXP_NODE_FORMAT].is_channels_first()

    if input_is_channels_first:
        if output_is_channels_first:
            # Both input and output must be permuted, or some merged permutations must be supported.
            if everything_merged_supported:
                # A single `Transpose` covers all 3 permutations.
                perm = everything_merged_permutation
            elif merged_pre_transpose_supported and separate_post_transpose_supported:
                # Input and main permutations merged; the output permutation gets its own `Transpose`.
                perm = merged_pre_transpose_permutation
                append_to_nhwc()
            elif separate_pre_transpose_supported and merged_post_transpose_supported:
                # The input permutation gets its own `Transpose`; main and output permutations merged.
                prepend_to_nchw()
                perm = merged_post_transpose_permutation
            elif (
                separate_pre_transpose_supported
                and main_perm_supported
                and separate_post_transpose_supported
            ):
                # Three separate `Transpose` operators; the main permutation stays as it is.
                prepend_to_nchw()
                append_to_nhwc()
            else:
                reject_unless_static()

            t_op.tmp_inputs[0].tensor_format = TensorFormat.CHANNELS_FIRST
            t_op.tmp_outputs[0].tensor_format = TensorFormat.CHANNELS_FIRST

        else:
            # Only the input must be permuted.
            if merged_pre_transpose_supported:
                # Use the combined permutation.
                perm = merged_pre_transpose_permutation
            elif separate_pre_transpose_supported and main_perm_supported:
                # Prepend a `Transpose` operator to make the input channels first.
                prepend_to_nchw()
            else:
                reject_unless_static()

            t_op.tmp_inputs[0].tensor_format = TensorFormat.CHANNELS_FIRST

    elif output_is_channels_first:
        # Only the output must be permuted.
        if merged_post_transpose_supported:
            # Use the combined permutation.
            perm = merged_post_transpose_permutation
        elif main_perm_supported and separate_post_transpose_supported:
            # Append a `Transpose` operator (using `to_nhwc_perm`) after the main one.
            append_to_nhwc()
        else:
            reject_unless_static()

        t_op.tmp_outputs[0].tensor_format = TensorFormat.CHANNELS_FIRST

    elif not main_perm_supported:
        # Neither the input nor the output have to be permuted, but the main permutation is not supported.
        reject_unless_static()

    # Use the final permutation as the operator's second input, keeping its current first input.
    perm_tensor = self.builder.create_tensor_for_data(
        np.array(perm, "int32"), "perm"
    )
    t_op.tmp_inputs = [t_op.tmp_inputs[0], perm_tensor]

    return ops
333+
58334
def convert(self, node: Node):
59335
"""Convert the `aten.permute_copy` operator to TFLite `Transpose`."""
60336
self.assert_convertible(node)
@@ -80,13 +356,6 @@ def convert(self, node: Node):
80356
"match. This indicates error in quantizer."
81357
)
82358

83-
perm = np.array(node.args[1], "int32")
84-
perm_tensor = self.builder.create_tensor_for_data(perm, "perm")
85-
86-
# Assign the operator its TFLite inputs and outputs
87-
t_op.tmp_inputs = [x, perm_tensor]
88-
t_op.tmp_outputs = [y]
89-
90-
ops_to_add = OpsList(middle_op=t_op)
359+
ops = self.handle_tensor_formats(t_op, node)
91360

92-
self.builder.append_operators(ops_to_add.flatten())
361+
self.builder.append_operators(ops.flatten())

backends/nxp/backend/node_format_inference.py

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,10 @@ class NodeFormatInference:
3030

3131
# A set of Edge Aten ops, which have the ability to change the format (for example - input nodes
3232
# are channels first but output is formatless).
33-
ops_that_can_change_tensor_format = {exir_ops.edge.aten.view_copy.default}
33+
ops_that_can_change_tensor_format = {
34+
exir_ops.edge.aten.view_copy.default,
35+
exir_ops.edge.aten.permute_copy.default,
36+
}
3437

3538
_type_changed_during_last_run: bool
3639

@@ -88,11 +91,23 @@ def _infer_format_of_nodes(self, node: Node):
8891

8992
if op_type in self.ops_with_channels_first_nodes:
9093
self._handle_node_which_uses_channels_first_format(node)
94+
9195
elif op_type in self.ops_that_can_change_tensor_format:
92-
if op_type == exir_ops.edge.aten.view_copy.default: # view_copy
96+
if op_type in [
97+
exir_ops.edge.aten.view_copy.default,
98+
exir_ops.edge.aten.permute_copy.default,
99+
]:
100+
# Try to assign the `formatless` format to the input and output. The converter will then handle the
101+
# transition.
102+
# Note: If the format for the input/output has already been assigned as channels first, it will NOT be
103+
# overwritten.
93104
self._assign_format_to_node(
94105
self._node_outputs[node][0], NodeFormat.FORMATLESS
95106
)
107+
self._assign_format_to_node(
108+
self._node_inputs[node][0], NodeFormat.FORMATLESS
109+
)
110+
96111
else:
97112
logger.error(
98113
f"Node format inference for node type: {op_type} not found!"

0 commit comments

Comments
 (0)