Skip to content

Commit efa0db8

Browse files
committed
NXP backend: Fix delegation check for aten.cat.default.
1 parent 4ea9ddf commit efa0db8

File tree

2 files changed

+39
-9
lines changed

2 files changed

+39
-9
lines changed

backends/nxp/backend/ir/converter/node_converters/ops_converters/cat_converter.py

Lines changed: 24 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
)
1111
from executorch.backends.nxp.backend.ir.converter.conversion import translator
1212
from executorch.backends.nxp.backend.ir.converter.conversion.translator import (
13+
apply_permutation_to,
1314
create_channels_first_to_channels_last_permutation,
1415
)
1516
from executorch.backends.nxp.backend.ir.converter.node_converter import (
@@ -85,9 +86,29 @@ def _is_supported_on_target(
8586

8687
dim = CatConverter._get_normalized_dim(node)
8788

88-
# neutron-library/src/utils/NeutronLibraryInterrogation.cpp#1491
89-
if dim == 0:
90-
return False
89+
# There is a bug in the NeutronConverter, where if none of the dimensions before the one referenced by
90+
# `dim` are `!= 1`, the `Concat` is not delegated.
91+
# This only happens when the inputs to the `Concat` are model inputs, and not outputs of other
92+
# operators. However, such a requirement cannot be currently enforced during partitioning, so we must
93+
# take the conservative approach and decline delegation.
94+
input_shapes = [list(n.meta["val"].shape) for n in node.args[0]]
95+
if node.meta[NXP_NODE_FORMAT].is_channels_first():
96+
# Transform the shapes to channels last.
97+
to_nhwc_perm = create_channels_first_to_channels_last_permutation(
98+
len(node.meta["val"].shape), True
99+
)
100+
input_shapes = [
101+
apply_permutation_to(shape, to_nhwc_perm)
102+
for shape in input_shapes
103+
]
104+
105+
# Transform the `dim` to refer to a channels last dimension.
106+
dim = to_nhwc_perm.index(dim)
107+
108+
for input_shape in input_shapes:
109+
if not any(d != 1 for d in input_shape[:dim]):
110+
# Do not delegate if there are no "non-1" dimensions in the shape before the `dim` dimension.
111+
return False
91112

92113
# Neutron requires the channels to be a multiple of `num_macs`. The channels could either be the second or the
93114
# last dimension, depending on the formats of the node.

backends/nxp/tests/ir/converter/node_converter/test_cat_converter.py

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,7 @@ def forward(self, *inputs: torch.Tensor):
7373
],
7474
)
7575
def test_cat__same_shapes(dim, num_inputs, rank, mocker):
76-
input_shape = tuple([2, 8, 8, 8, 8][-rank:])
76+
input_shape = tuple([8, 8, 8, 8][:rank])
7777

7878
converter_spy = mocker.spy(EdgeProgramToIRConverter, "convert_program")
7979

@@ -134,11 +134,20 @@ def test_cat__channels_first__same_shapes(dim, num_inputs, mocker):
134134
)
135135

136136

137-
@pytest.mark.parametrize("dim", [0, -4])
138-
@pytest.mark.parametrize("num_inputs", [2])
139-
def test_cat__unsupported_dim__imxrt700(dim, num_inputs):
140-
input_shape = (2, 8, 6, 8)
141-
137+
@pytest.mark.parametrize(
138+
"dim, input_shape",
139+
[
140+
pytest.param(0, (1, 8, 8, 8), id="axis = 0"),
141+
pytest.param(0, (8, 8, 8, 8), id="axis = 0, no `1s` in the shape."),
142+
pytest.param(-4, (1, 8, 8, 8), id="axis = -4"),
143+
pytest.param(1, (1, 1, 8, 8), id="axis = 1"),
144+
pytest.param(-3, (1, 1, 8, 8), id="axis = -3"),
145+
pytest.param(2, (1, 1, 1, 8), id="axis = 2"),
146+
pytest.param(-2, (1, 1, 1, 8), id="axis = -2"),
147+
],
148+
)
149+
def test_cat__unsupported__imxrt700(dim, input_shape):
150+
num_inputs = 2
142151
quantized_program = to_quantized_edge_program(
143152
CatModule(dim), [input_shape] * num_inputs, target="imxrt700"
144153
).exported_program()

0 commit comments

Comments
 (0)