Skip to content

Commit 2ca11c2

Browse files
committed
NXP backend: Improve cat delegation by using inferred node formats.
1 parent 0b8ffda commit 2ca11c2

File tree

2 files changed

+137
-28
lines changed

2 files changed

+137
-28
lines changed

backends/nxp/backend/ir/converter/node_converters/ops_converters/cat_converter.py

Lines changed: 41 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,9 @@
99
CustomDelegationOptions,
1010
)
1111
from executorch.backends.nxp.backend.ir.converter.conversion import translator
12+
from executorch.backends.nxp.backend.ir.converter.conversion.translator import (
13+
create_channels_first_to_channels_last_permutation,
14+
)
1215
from executorch.backends.nxp.backend.ir.converter.node_converter import (
1316
_is_dequant_node,
1417
_is_quant_node,
@@ -18,6 +21,7 @@
1821
Concatenation,
1922
)
2023
from executorch.backends.nxp.backend.neutron_target_spec import NeutronTargetSpec
24+
from executorch.backends.nxp.backend.node_format_inference import NXP_NODE_FORMAT
2125
from torch.fx import Node
2226
from torch.nn import Parameter
2327

@@ -85,39 +89,48 @@ def _is_supported_on_target(
8589
if dim == 0:
8690
return False
8791

88-
# If all input shapes are equal, the neutron is able to pad the last dimension of inputs and outputs.
89-
input_shapes = [_get_shape(input_) for input_ in node.all_input_nodes]
90-
if input_shapes.count(input_shapes[0]) == len(input_shapes):
91-
if dim == len(input_shapes[0]) - 1:
92-
return True
92+
# Neutron requires the channels to be a multiple of `num_macs`. The channels could either be the second or the
93+
# last dimension, depending on the formats of the node.
94+
if node.meta[NXP_NODE_FORMAT].is_channels_first():
95+
# During conversion to IR, the shape will be permuted to channels last, and the dimension on index
96+
# `1` will end up being the channels (last dim in NHWC).
97+
channels_index = 1
98+
to_nhwc_perm = create_channels_first_to_channels_last_permutation(
99+
len(node.meta["val"].shape), True
100+
)
101+
dim = to_nhwc_perm.index(
102+
dim
103+
) # Make sure the dim points to the NHWC dimension.
104+
else:
105+
# The shape will not be permuted during conversion, so the channels will remain the last dimension.
106+
channels_index = -1
93107

94-
# Neutron requires the channels to be a multiple of numMacs. The channels could either be the second or the
95-
# last dimension, depending on the formats of the node. The format, however, cannot be determined
96-
# during conversion, as it depends on what other nodes are delegated.
97108
input_channels = [
98-
# The second dimension is the channels in PyTorch. If the inputs/output are not channels first, it
99-
# will still be the channels in the IR.
100-
_get_shape(input_)[1]
101-
for input_ in node.all_input_nodes
102-
] + [
103-
# If the inputs/outputs are channels first, the last dimension will be the channels.
104-
_get_shape(input_)[-1]
105-
for input_ in node.all_input_nodes
109+
_get_shape(input_)[channels_index] for input_ in node.all_input_nodes
106110
]
107-
if any(
108-
(input_channel % neutron_target_spec.get_num_macs()) != 0
109-
for input_channel in input_channels
110-
):
111+
output_channels = _get_shape(node)[channels_index]
112+
113+
num_macs = neutron_target_spec.get_num_macs()
114+
input_shapes = [_get_shape(input_) for input_ in node.all_input_nodes]
115+
if any((input_channel % num_macs) != 0 for input_channel in input_channels):
111116
# neutron-library/src/utils/NeutronLibraryInterrogation.cpp#1492
112-
return False
113117

114-
output_channels = [_get_shape(node)[1], _get_shape(node)[-1]]
115-
# neutron-library/src/utils/NeutronLibraryInterrogation.cpp#1493
116-
if any(
117-
(out_c % neutron_target_spec.get_num_macs()) != 0
118-
for out_c in output_channels
119-
):
120-
return False
118+
# If all input shapes are equal, the neutron is able to pad the last dimension of the inputs.
119+
if not (
120+
input_shapes.count(input_shapes[0]) == len(input_shapes)
121+
and dim == len(input_shapes[0]) - 1
122+
):
123+
return False
124+
125+
if (output_channels % num_macs) != 0:
126+
# neutron-library/src/utils/NeutronLibraryInterrogation.cpp#1493
127+
128+
# If all input shapes are equal, the neutron is able to pad the last dimension of the output.
129+
if not (
130+
input_shapes.count(input_shapes[0]) == len(input_shapes)
131+
and dim == len(input_shapes[0]) - 1
132+
):
133+
return False
121134

122135
if len(node.all_input_nodes) < 2: # Not supported on Neutron
123136
# TODO Try to skip the operator if this case is realistic.

backends/nxp/tests/ir/converter/node_converter/test_cat_converter.py

Lines changed: 96 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -319,6 +319,27 @@ def test_cat__same_shapes_converter_padding_last_dimension():
319319
assert any("lowered_module" in node.name for node in quantized_program.graph.nodes)
320320

321321

322+
def test_cat__same_shapes__channels_first__padding_channels():
323+
target = "imxrt700"
324+
325+
# The Converter is capable of padding the last dimension of `cat` with the same input shapes.
326+
input_shape = (1, 2, 3, 4)
327+
328+
quantized_program = to_quantized_edge_program(
329+
CatConvModule(1),
330+
[input_shape, input_shape],
331+
target=target,
332+
neutron_converter_flavor="SDK_25_09",
333+
custom_delegation_options=CustomDelegationOptions(),
334+
).exported_program()
335+
336+
# Make sure the `Cat` was delegated.
337+
assert not graph_contains_any_of_ops(
338+
graph=quantized_program.graph, ops=[exir_ops.edge.aten.cat.default]
339+
)
340+
assert any("lowered_module" in node.name for node in quantized_program.graph.nodes)
341+
342+
322343
def test_cat__same_shapes_converter_padding_middle_dimension():
323344
target = "imxrt700"
324345

@@ -339,3 +360,78 @@ def test_cat__same_shapes_converter_padding_middle_dimension():
339360
assert not any(
340361
"lowered_module" in node.name for node in quantized_program.graph.nodes
341362
)
363+
364+
365+
def test_cat__format_specific_support__formatless(mocker):
366+
# The last dim will end up being the channels, as the format is `formatless`.
367+
# Only the last dim satisfies the Neutron requirements for the channels.
368+
input_shape = (3, 3, 3, 8)
369+
num_inputs = 2
370+
dim = 2
371+
372+
input_shapes = [input_shape] * num_inputs
373+
374+
converter_spy = mocker.spy(EdgeProgramToIRConverter, "convert_program")
375+
376+
quantized_program = to_quantized_edge_program(
377+
CatModule(dim), input_shapes
378+
).exported_program()
379+
380+
# Make sure the `Cat` was delegated.
381+
assert not graph_contains_any_of_ops(
382+
graph=quantized_program.graph, ops=[exir_ops.edge.aten.cat.default]
383+
)
384+
assert any("lowered_module" in node.name for node in quantized_program.graph.nodes)
385+
386+
tflite_flatbuffers_model, io_formats = converter_spy.spy_return
387+
exported_program: ExportedProgram = converter_spy.call_args.args[1]
388+
input_data = {
389+
i: (np.random.random(shape) * 50).astype(np.int8)
390+
for i, shape in enumerate(input_shapes)
391+
}
392+
convert_run_compare(
393+
exported_program,
394+
tfl_model=tflite_flatbuffers_model,
395+
input_data=input_data,
396+
atol=1,
397+
)
398+
399+
400+
def test_cat__format_specific_support__channels_first(mocker):
401+
# The second dim will end up being the channels, as the format is `channels first`.
402+
# Only the second dim satisfies the Neutron requirements for the channels.
403+
input_shape = (3, 8, 3, 3)
404+
num_inputs = 2
405+
dim = 2
406+
407+
input_shapes = [input_shape] * num_inputs
408+
409+
converter_spy = mocker.spy(EdgeProgramToIRConverter, "convert_program")
410+
411+
channels = (
412+
sum(shape[1] for shape in input_shapes) if dim in [1, -3] else input_shape[1]
413+
)
414+
quantized_program = to_quantized_edge_program(
415+
CatConvModule(dim, channels), input_shapes
416+
).exported_program()
417+
418+
# Make sure the `Cat` was delegated.
419+
assert not graph_contains_any_of_ops(
420+
graph=quantized_program.graph, ops=[exir_ops.edge.aten.cat.default]
421+
)
422+
assert any("lowered_module" in node.name for node in quantized_program.graph.nodes)
423+
424+
tflite_flatbuffers_model, io_formats = converter_spy.spy_return
425+
exported_program: ExportedProgram = converter_spy.call_args.args[1]
426+
input_data = {
427+
i: (np.random.random(shape) * 50).astype(np.int8)
428+
for i, shape in enumerate(input_shapes)
429+
}
430+
convert_run_compare(
431+
exported_program,
432+
tfl_model=tflite_flatbuffers_model,
433+
input_data=input_data,
434+
tflite_input_preprocess=ToNHWCPreprocess(),
435+
tflite_output_preprocess=ToNCHWPreprocess(),
436+
atol=1,
437+
)

0 commit comments

Comments
 (0)