Skip to content

Commit 1f114f1

Browse files
committed
The padding of concat is not needed anymore if the inputs are equal.
1 parent bf7d755 commit 1f114f1

File tree

2 files changed

+49
-0
lines changed

2 files changed

+49
-0
lines changed

backends/nxp/backend/ir/converter/node_converters/ops_converters/cat_converter.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,12 @@ def _is_supported_on_target(
8585
if dim == 0:
8686
return False
8787

88+
# If all input shapes are equal, the neutron is able to pad the last dimension of inputs and outputs.
89+
input_shapes = [_get_shape(input_) for input_ in node.all_input_nodes]
90+
if input_shapes.count(input_shapes[0]) == len(input_shapes):
91+
if dim == len(input_shapes[0]) - 1:
92+
return True
93+
8894
# Neutron requires the channels to be a multiple of numMacs. The channels could either be the second or the
8995
# last dimension, depending on the formats of the node. The format, however, cannot be determined
9096
# during conversion, as it depends on what other nodes are delegated.

backends/nxp/tests/ir/converter/node_converter/test_cat_converter.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -290,3 +290,46 @@ def test_cat__force_delegate():
290290
graph=quantized_program.graph, ops=[exir_ops.edge.aten.cat.default]
291291
)
292292
assert any("lowered_module" in node.name for node in quantized_program.graph.nodes)
293+
294+
295+
def test_cat__same_shapes_converter_padding_last_dimension():
296+
target = "imxrt700"
297+
298+
# The Converter is capable of padding the last dimension of `cat` with the same input shapes.
299+
input_shape = (3, 1, 3)
300+
301+
quantized_program = to_quantized_edge_program(
302+
CatModule(2),
303+
[input_shape, input_shape],
304+
target=target,
305+
neutron_converter_flavor="SDK_25_09",
306+
custom_delegation_options=CustomDelegationOptions(),
307+
).exported_program()
308+
309+
# Make sure the `Cat` was delegated.
310+
assert not graph_contains_any_of_ops(
311+
graph=quantized_program.graph, ops=[exir_ops.edge.aten.cat.default]
312+
)
313+
assert any("lowered_module" in node.name for node in quantized_program.graph.nodes)
314+
315+
316+
def test_cat__same_shapes_converter_padding_middle_dimension():
317+
target = "imxrt700"
318+
319+
# The Converter is not capable of padding the middle dimensions of `cat` with the same input shapes.
320+
input_shape = (3, 1, 3)
321+
322+
quantized_program = to_quantized_edge_program(
323+
CatModule(1),
324+
[input_shape, input_shape],
325+
target=target,
326+
custom_delegation_options=CustomDelegationOptions(),
327+
).exported_program()
328+
329+
# Make sure the `Cat` was NOT delegated.
330+
assert graph_contains_any_of_ops(
331+
graph=quantized_program.graph, ops=[exir_ops.edge.aten.cat.default]
332+
)
333+
assert not any(
334+
"lowered_module" in node.name for node in quantized_program.graph.nodes
335+
)

0 commit comments

Comments
 (0)