Skip to content

Commit 58f191b

Browse files
committed
NXP backend: Improve cat delegation by using inferred node formats.
1 parent e03b048 commit 58f191b

File tree

2 files changed

+93
-20
lines changed

2 files changed

+93
-20
lines changed

backends/nxp/backend/ir/converter/node_converters/ops_converters/cat_converter.py

Lines changed: 18 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
Concatenation,
1919
)
2020
from executorch.backends.nxp.backend.neutron_target_spec import NeutronTargetSpec
21+
from executorch.backends.nxp.backend.node_format_inference import NXP_NODE_FORMAT
2122
from torch.fx import Node
2223
from torch.nn import Parameter
2324

@@ -85,32 +86,29 @@ def _is_supported_on_target(
8586
if dim == 0:
8687
return False
8788

88-
# Neutron requires the channels to be a multiple of numMacs. The channels could either be the second or the
89-
# last dimension, depending on the formats of the node. The format, however, cannot be determined
90-
# during conversion, as it depends on what other nodes are delegated.
89+
# Neutron requires the channels to be a multiple of `num_macs`. The channels could either be the second or the
90+
# last dimension, depending on the formats of the node.
91+
if node.meta[NXP_NODE_FORMAT].is_channels_first():
92+
# During conversion to IR, the shape will be permuted to channels last, and the dimension on index
93+
# `1` will end up being the channels (last dim in NHWC).
94+
channels_index = 1
95+
else:
96+
# The shape will not be permuted during conversion, so the channels will remain the last dimension.
97+
channels_index = -1
98+
9199
input_channels = [
92-
# The second dimension is the channels in PyTorch. If the inputs/output are not channels first, it
93-
# will still be the channels in the IR.
94-
_get_shape(input_)[1]
95-
for input_ in node.all_input_nodes
96-
] + [
97-
# If the inputs/outputs are channels first, the last dimension will be the channels.
98-
_get_shape(input_)[-1]
100+
_get_shape(input_)[channels_index]
99101
for input_ in node.all_input_nodes
100102
]
101-
if any(
102-
(input_channel % neutron_target_spec.get_num_macs()) != 0
103-
for input_channel in input_channels
104-
):
103+
output_channels = _get_shape(node)[channels_index]
104+
105+
num_macs = neutron_target_spec.get_num_macs()
106+
if any((input_channel % num_macs) != 0 for input_channel in input_channels):
105107
# neutron-library/src/utils/NeutronLibraryInterrogation.cpp#1492
106108
return False
107109

108-
output_channels = [_get_shape(node)[1], _get_shape(node)[-1]]
109-
# neutron-library/src/utils/NeutronLibraryInterrogation.cpp#1493
110-
if any(
111-
(out_c % neutron_target_spec.get_num_macs()) != 0
112-
for out_c in output_channels
113-
):
110+
if (output_channels % num_macs) != 0:
111+
# neutron-library/src/utils/NeutronLibraryInterrogation.cpp#1493
114112
return False
115113

116114
if len(node.all_input_nodes) < 2: # Not supported on Neutron

backends/nxp/tests/ir/converter/node_converter/test_cat_converter.py

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -296,3 +296,78 @@ def test_cat__force_delegate():
296296
graph=quantized_program.graph, ops=[exir_ops.edge.aten.cat.default]
297297
)
298298
assert any("lowered_module" in node.name for node in quantized_program.graph.nodes)
299+
300+
301+
def test_cat__format_specific_support__formatless(mocker):
302+
# The last dim will end up being the channels, as the format is `formatless`.
303+
# Only the last dim satisfies the Neutron requirements for the channels.
304+
input_shape = (3, 3, 3, 8)
305+
num_inputs = 2
306+
dim = 2
307+
308+
input_shapes = [input_shape] * num_inputs
309+
310+
converter_spy = mocker.spy(EdgeProgramToIRConverter, "convert_program")
311+
312+
quantized_program = to_quantized_edge_program(
313+
CatModule(dim), input_shapes
314+
).exported_program()
315+
316+
# Make sure the `Cat` was delegated.
317+
assert not graph_contains_any_of_ops(
318+
graph=quantized_program.graph, ops=[exir_ops.edge.aten.cat.default]
319+
)
320+
assert any("lowered_module" in node.name for node in quantized_program.graph.nodes)
321+
322+
tflite_flatbuffers_model, io_formats = converter_spy.spy_return
323+
exported_program: ExportedProgram = converter_spy.call_args.args[1]
324+
input_data = {
325+
i: (np.random.random(shape) * 50).astype(np.int8)
326+
for i, shape in enumerate(input_shapes)
327+
}
328+
convert_run_compare(
329+
exported_program,
330+
tfl_model=tflite_flatbuffers_model,
331+
input_data=input_data,
332+
atol=1,
333+
)
334+
335+
336+
def test_cat__format_specific_support__channels_first(mocker):
337+
# The second dim will end up being the channels, as the format is `formatless`.
338+
# Only the second dim satisfies the Neutron requirements for the channels.
339+
input_shape = (3, 8, 3, 3)
340+
num_inputs = 2
341+
dim = 2
342+
343+
input_shapes = [input_shape] * num_inputs
344+
345+
converter_spy = mocker.spy(EdgeProgramToIRConverter, "convert_program")
346+
347+
channels = (
348+
sum(shape[1] for shape in input_shapes) if dim in [1, -3] else input_shape[1]
349+
)
350+
quantized_program = to_quantized_edge_program(
351+
CatConvModule(dim, channels), input_shapes
352+
).exported_program()
353+
354+
# Make sure the `Cat` was delegated.
355+
assert not graph_contains_any_of_ops(
356+
graph=quantized_program.graph, ops=[exir_ops.edge.aten.cat.default]
357+
)
358+
assert any("lowered_module" in node.name for node in quantized_program.graph.nodes)
359+
360+
tflite_flatbuffers_model, io_formats = converter_spy.spy_return
361+
exported_program: ExportedProgram = converter_spy.call_args.args[1]
362+
input_data = {
363+
i: (np.random.random(shape) * 50).astype(np.int8)
364+
for i, shape in enumerate(input_shapes)
365+
}
366+
convert_run_compare(
367+
exported_program,
368+
tfl_model=tflite_flatbuffers_model,
369+
input_data=input_data,
370+
tflite_input_preprocess=ToNHWCPreprocess(),
371+
tflite_output_preprocess=ToNCHWPreprocess(),
372+
atol=1,
373+
)

0 commit comments

Comments
 (0)