diff --git a/backends/nxp/backend/ir/converter/node_converters/ops_converters/cat_converter.py b/backends/nxp/backend/ir/converter/node_converters/ops_converters/cat_converter.py index 22ca258cd4f..96b40f30ced 100644 --- a/backends/nxp/backend/ir/converter/node_converters/ops_converters/cat_converter.py +++ b/backends/nxp/backend/ir/converter/node_converters/ops_converters/cat_converter.py @@ -85,6 +85,12 @@ def _is_supported_on_target( if dim == 0: return False + # If all input shapes are equal, the neutron is able to pad the last dimension of inputs and outputs. + input_shapes = [_get_shape(input_) for input_ in node.all_input_nodes] + if input_shapes.count(input_shapes[0]) == len(input_shapes): + if dim == len(input_shapes[0]) - 1: + return True + # Neutron requires the channels to be a multiple of numMacs. The channels could either be the second or the # last dimension, depending on the formats of the node. The format, however, cannot be determined # during conversion, as it depends on what other nodes are delegated. diff --git a/backends/nxp/tests/ir/converter/node_converter/test_cat_converter.py b/backends/nxp/tests/ir/converter/node_converter/test_cat_converter.py index 3df703f5bba..1a003f9b685 100644 --- a/backends/nxp/tests/ir/converter/node_converter/test_cat_converter.py +++ b/backends/nxp/tests/ir/converter/node_converter/test_cat_converter.py @@ -290,3 +290,46 @@ def test_cat__force_delegate(): graph=quantized_program.graph, ops=[exir_ops.edge.aten.cat.default] ) assert any("lowered_module" in node.name for node in quantized_program.graph.nodes) + + +def test_cat__same_shapes_converter_padding_last_dimension(): + target = "imxrt700" + + # The Converter is capable of padding the last dimension of `cat` with the same input shapes. + input_shape = (3, 1, 3) + + quantized_program = to_quantized_edge_program( + CatModule(2), + [input_shape, input_shape], + target=target, + neutron_converter_flavor="SDK_25_09", + custom_delegation_options=CustomDelegationOptions(), + ).exported_program() + + # Make sure the `Cat` was delegated. + assert not graph_contains_any_of_ops( + graph=quantized_program.graph, ops=[exir_ops.edge.aten.cat.default] + ) + assert any("lowered_module" in node.name for node in quantized_program.graph.nodes) + + +def test_cat__same_shapes_converter_padding_middle_dimension(): + target = "imxrt700" + + # The Converter is not capable of padding the middle dimensions of `cat` with the same input shapes. + input_shape = (3, 1, 3) + + quantized_program = to_quantized_edge_program( + CatModule(1), + [input_shape, input_shape], + target=target, + custom_delegation_options=CustomDelegationOptions(), + ).exported_program() + + # Make sure the `Cat` was NOT delegated. + assert graph_contains_any_of_ops( + graph=quantized_program.graph, ops=[exir_ops.edge.aten.cat.default] + ) + assert not any( + "lowered_module" in node.name for node in quantized_program.graph.nodes + )