Skip to content

Commit 13773cd

Browse files
committed
NXP backend: Add context-dependent cat partitioning.
1 parent efa0db8 commit 13773cd

File tree

3 files changed

+113
-32
lines changed

3 files changed

+113
-32
lines changed

backends/nxp/backend/edge_helper.py

Lines changed: 30 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -22,14 +22,6 @@
2222
]
2323

2424

25-
def _is_dequantize(node_: Node) -> bool:
26-
return node_.op == "call_function" and node_.target in DEQUANTIZE_OPERATORS
27-
28-
29-
def _is_quantize(node_: Node) -> bool:
30-
return node_.op == "call_function" and node_.target in QUANTIZE_OPERATORS
31-
32-
3325
def input_tensor(node: Node, input_index: int) -> torch.Tensor:
3426
if len(node.all_input_nodes) <= input_index:
3527
raise IndexError
@@ -103,3 +95,33 @@ def try_get_tensor_constant_from_node(
10395
return None
10496
attr_itr = getattr(attr_itr, atom)
10597
return attr_itr
98+
99+
100+
def _is_dequantize(node_: Node) -> bool:
    """Tell whether `node_` is a `call_function` node whose target is a known `dequantize` operator.

    Both the edge dialect and the ATen dialect variants (per tensor and per channel) are recognized.
    """
    if node_.op != "call_function":
        return False

    dequantize_targets = (
        exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default,
        exir_ops.edge.quantized_decomposed.dequantize_per_channel.default,
        torch.ops.quantized_decomposed.dequantize_per_tensor.default,
        torch.ops.quantized_decomposed.dequantize_per_channel.default,
    )
    return node_.target in dequantize_targets
107+
108+
109+
def _is_quantize(node_: Node) -> bool:
    """Tell whether `node_` is a `call_function` node whose target is a known `quantize` operator.

    Both the edge dialect and the ATen dialect variants (per tensor and per channel) are recognized.
    """
    if node_.op != "call_function":
        return False

    quantize_targets = (
        exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
        exir_ops.edge.quantized_decomposed.quantize_per_channel.default,
        torch.ops.quantized_decomposed.quantize_per_tensor.default,
        torch.ops.quantized_decomposed.quantize_per_channel.default,
    )
    return node_.target in quantize_targets
116+
117+
118+
def previous_non_qdq_node(node: Node, input_index: int = 0) -> Node | None:
    """Return the first node which is not a `quantize` or `dequantize`, found by traversing the graph backwards
     starting with `node.args[input_index]`.

    Every `quantize` / `dequantize` node encountered on the way is skipped by following its first argument.

    :param node: Node whose input chain is inspected.
    :param input_index: Index into `node.args` selecting the input where the traversal starts.
    :return: The first non-Q/DQ node found, or `None` if the traversal reaches an argument which is not a
              `Node` (for example a literal), so callers never receive a non-`Node` object.
    :raises IndexError: If `node.args` has no element at `input_index`, or a skipped Q/DQ node has no arguments.
    """
    current_node = node.args[input_index]

    # The `isinstance` guard prevents an `AttributeError` on `.op` when the chain ends in a non-`Node` argument.
    while isinstance(current_node, Node) and (
        _is_quantize(current_node) or _is_dequantize(current_node)
    ):
        # Step over the Q/DQ node by following its first argument.
        current_node = current_node.args[0]

    return current_node if isinstance(current_node, Node) else None

backends/nxp/backend/ir/converter/node_converters/ops_converters/cat_converter.py

Lines changed: 42 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from executorch.backends.nxp.backend.custom_delegation_options import (
99
CustomDelegationOptions,
1010
)
11+
from executorch.backends.nxp.backend.edge_helper import previous_non_qdq_node
1112
from executorch.backends.nxp.backend.ir.converter.conversion import translator
1213
from executorch.backends.nxp.backend.ir.converter.conversion.translator import (
1314
apply_permutation_to,
@@ -24,6 +25,7 @@
2425
from executorch.backends.nxp.backend.neutron_target_spec import NeutronTargetSpec
2526
from executorch.backends.nxp.backend.node_format import NXP_NODE_FORMAT
2627
from torch.fx import Node
28+
from torch.fx.passes.infra.partitioner import Partition
2729
from torch.nn import Parameter
2830

2931

@@ -86,30 +88,6 @@ def _is_supported_on_target(
8688

8789
dim = CatConverter._get_normalized_dim(node)
8890

89-
# There is a bug in the NeutronConverter, where if none of the dimensions before the one referenced by
90-
# `dim` are `!= 1`, the `Concat` is not delegated.
91-
# This only happens when the inputs to the `Concat` are model inputs, and not outputs of other
92-
# operators. However, such a requirement cannot be currently enforced during partitioning, so we must
93-
# take the conservative approach and decline delegation.
94-
input_shapes = [list(n.meta["val"].shape) for n in node.args[0]]
95-
if node.meta[NXP_NODE_FORMAT].is_channels_first():
96-
# Transform the shapes to channels last.
97-
to_nhwc_perm = create_channels_first_to_channels_last_permutation(
98-
len(node.meta["val"].shape), True
99-
)
100-
input_shapes = [
101-
apply_permutation_to(shape, to_nhwc_perm)
102-
for shape in input_shapes
103-
]
104-
105-
# Transform the `dim` to refer to a channels last dimension.
106-
dim = to_nhwc_perm.index(dim)
107-
108-
for input_shape in input_shapes:
109-
if not any(d != 1 for d in input_shape[:dim]):
110-
# Do not delegate if there are no "non-1" dimensions in the shape before the `dim` dimension.
111-
return False
112-
11391
# Neutron requires the channels to be a multiple of `num_macs`. The channels could either be the second or the
11492
# last dimension, depending on the formats of the node.
11593
if node.meta[NXP_NODE_FORMAT].is_channels_first():
@@ -172,6 +150,46 @@ def _is_supported_in_IR(
172150

173151
return True
174152

153+
@classmethod
def supports_partitioning_result(
    cls,
    node: Node,
    partition_list: list[Partition],
    custom_delegation_options: CustomDelegationOptions,
) -> bool:
    """Decide whether the `cat` node can stay delegated, given the proposed partitioning.

    There is a bug in the NeutronConverter, where if none of the input dimensions before the one referenced by
     `dim` are `!= 1`, the `Concat` is not delegated. This only happens when the inputs to the `Concat` are model
     inputs, and not outputs of other operators. The shape restriction is therefore only enforced when some input
     of the `cat` (ignoring `quantize` / `dequantize` nodes) is not a `call_function` node in the same partition.

    :param node: The `cat` node to check.
    :param partition_list: Partitions proposed by the partitioner.
    :param custom_delegation_options: Custom delegation options (not used by this check).
    :return: `True` if the `cat` can be delegated in this context, `False` otherwise.
    """
    # `next` with a default avoids an `IndexError` when no partition contains the `cat`, and stops at the first
    #  match instead of building a throwaway list.
    cat_partition = next((p for p in partition_list if node in p.nodes), None)

    all_inputs_in_partition = cat_partition is not None and all(
        input_ is not None
        and input_.op == "call_function"
        and input_ in cat_partition.nodes
        for input_ in map(previous_non_qdq_node, node.args[0])
    )
    if all_inputs_in_partition:
        # Every input is produced by an operator in the same partition, so the bug cannot be triggered.
        return True

    # Some inputs of the `cat` are NOT in the same partition as `cat`.
    dim = CatConverter._get_normalized_dim(node)
    input_shapes = [list(n.meta["val"].shape) for n in node.args[0]]
    if node.meta[NXP_NODE_FORMAT].is_channels_first():
        # Transform the shapes to channels last.
        to_nhwc_perm = create_channels_first_to_channels_last_permutation(
            len(node.meta["val"].shape), True
        )
        input_shapes = [
            apply_permutation_to(shape, to_nhwc_perm) for shape in input_shapes
        ]

        # Transform the `dim` to refer to a channels last dimension.
        dim = to_nhwc_perm.index(dim)

    # Do not delegate if any input has no "non-1" dimension in the shape before the `dim` dimension.
    return all(
        any(d != 1 for d in input_shape[:dim]) for input_shape in input_shapes
    )
192+
175193
def convert(self, node: Node):
176194
"""Convert the 'aten.cat' operator to TFLite 'Concatenation'."""
177195
self.assert_convertible(node)

backends/nxp/tests/ir/converter/node_converter/test_cat_converter.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,18 @@ def forward(self, *inputs: torch.Tensor):
4444
return torch.cat(list(inputs), self.dim)
4545

4646

47+
class AddCatModule(torch.nn.Module):
    """Model which adds every input to itself and concatenates the results along `dim`.

    The `add` operators make sure the inputs of the `cat` are outputs of other operators, not model inputs.
    """

    def __init__(self, dim: int):
        """Store the dimension along which the inputs are concatenated."""
        super().__init__()
        self.dim = dim

    def forward(self, *inputs: torch.Tensor):
        """Return the concatenation of `input + input` for every given input, along `self.dim`."""
        # The comprehension already yields a list, so no extra `list()` copy is needed before `torch.cat`.
        doubled_inputs = [input_ + input_ for input_ in inputs]

        return torch.cat(doubled_inputs, self.dim)
57+
58+
4759
class CatConvModule(torch.nn.Module):
4860

4961
def __init__(self, dim: int, channels: int = 4):
@@ -147,6 +159,9 @@ def test_cat__channels_first__same_shapes(dim, num_inputs, mocker):
147159
],
148160
)
149161
def test_cat__unsupported__imxrt700(dim, input_shape):
162+
"""This test is conjoined with the one below (`test_cat__context_dependent__imxrt700`).
163+
In this case, the inputs of the `cat` are NOT compute ops, so the `cat` is NOT delegated.
164+
"""
150165
num_inputs = 2
151166
quantized_program = to_quantized_edge_program(
152167
CatModule(dim), [input_shape] * num_inputs, target="imxrt700"
@@ -161,6 +176,32 @@ def test_cat__unsupported__imxrt700(dim, input_shape):
161176
)
162177

163178

179+
@pytest.mark.parametrize(
    "dim, input_shape",
    [
        pytest.param(0, (1, 8, 8, 8), id="axis = 0"),
        pytest.param(0, (8, 8, 8, 8), id="axis = 0, no `1s` in the shape."),
        pytest.param(-4, (1, 8, 8, 8), id="axis = -4"),
        pytest.param(1, (1, 1, 8, 8), id="axis = 1"),
        pytest.param(-3, (1, 1, 8, 8), id="axis = -3"),
        pytest.param(2, (1, 1, 1, 8), id="axis = 2"),
        pytest.param(-2, (1, 1, 1, 8), id="axis = -2"),
    ],
)
def test_cat__context_dependent__imxrt700(dim, input_shape):
    """Counterpart of `test_cat__unsupported__imxrt700`, using the same `dim` / shape combinations.
    Here the inputs of the `cat` are produced by compute ops (`add`), so the `cat` is expected to be delegated.
    """
    # Two inputs with the same shape.
    input_shapes = [input_shape, input_shape]
    edge_program_manager = to_quantized_edge_program(
        AddCatModule(dim), input_shapes, target="imxrt700"
    )
    graph = edge_program_manager.exported_program().graph

    # No `cat` may remain in the graph, and a delegated call must be present.
    cat_ops = [exir_ops.edge.aten.cat.default]
    assert not graph_contains_any_of_ops(graph, cat_ops)

    node_names = [node.name for node in graph.nodes]
    assert any("lowered_module" in name for name in node_names)
203+
204+
164205
@pytest.mark.parametrize(
165206
"rank, num_inputs, dim",
166207
[

0 commit comments

Comments
 (0)