Skip to content

Commit 4067d90

Browse files
committed
NXP backend: Add context dependent cat partitioning.
1 parent 67047f3 commit 4067d90

File tree

3 files changed

+115
-26
lines changed

3 files changed

+115
-26
lines changed

backends/nxp/backend/edge_helper.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44
# LICENSE file in the root directory of this source tree.
55

66
import torch
7+
8+
from executorch.exir.dialects._ops import ops as exir_ops
79
from torch.fx import GraphModule, Node
810
from torch.nn import Parameter
911

@@ -87,3 +89,33 @@ def try_get_tensor_constant_from_node(
8789
return None
8890
attr_itr = getattr(attr_itr, atom)
8991
return attr_itr
92+
93+
94+
def _is_dequantize(node_: Node) -> bool:
95+
return node_.op == "call_function" and node_.target in [
96+
exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default,
97+
exir_ops.edge.quantized_decomposed.dequantize_per_channel.default,
98+
torch.ops.quantized_decomposed.dequantize_per_tensor.default,
99+
torch.ops.quantized_decomposed.dequantize_per_channel.default,
100+
]
101+
102+
103+
def _is_quantize(node_: Node) -> bool:
104+
return node_.op == "call_function" and node_.target.__name__ in [
105+
exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
106+
exir_ops.edge.quantized_decomposed.quantize_per_channel.default,
107+
torch.ops.quantized_decomposed.quantize_per_tensor.default,
108+
torch.ops.quantized_decomposed.quantize_per_channel.default,
109+
]
110+
111+
112+
def previous_non_qdq_node(node: Node, input_index: int = 0) -> Node | None:
    """Return the first node which is neither a `quantize` nor a `dequantize`, found by traversing the graph
     backwards, starting with `node.args[input_index]`.

    Every skipped `quantize`/`dequantize` node is followed through its first argument (`args[0]`).
    """
    candidate = node.args[input_index]

    # Step over the chain of Q/DQ nodes until some other node is reached.
    while _is_quantize(candidate) or _is_dequantize(candidate):
        candidate = candidate.args[0]

    return candidate

backends/nxp/backend/ir/converter/node_converters/ops_converters/cat_converter.py

Lines changed: 42 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from executorch.backends.nxp.backend.custom_delegation_options import (
99
CustomDelegationOptions,
1010
)
11+
from executorch.backends.nxp.backend.edge_helper import previous_non_qdq_node
1112
from executorch.backends.nxp.backend.ir.converter.conversion import translator
1213
from executorch.backends.nxp.backend.ir.converter.conversion.translator import (
1314
apply_permutation_to,
@@ -24,6 +25,7 @@
2425
from executorch.backends.nxp.backend.neutron_target_spec import NeutronTargetSpec
2526
from executorch.backends.nxp.backend.node_format import NXP_NODE_FORMAT
2627
from torch.fx import Node
28+
from torch.fx.passes.infra.partitioner import Partition
2729
from torch.nn import Parameter
2830

2931

@@ -84,32 +86,6 @@ def _is_supported_on_target(
8486
if custom_delegation_options.force_delegate_cat:
8587
return True
8688

87-
dim = CatConverter._get_normalized_dim(node)
88-
89-
# There is a bug in the NeutronConverter, where if none of the dimensions before the one referenced by
90-
# `dim` are `!= 1`, the `Concat` is not delegated.
91-
# This only happens when the inputs to the `Concat` are model inputs, and not outputs of other
92-
# operators. However, such a requirement cannot be currently enforced during partitioning, so we must
93-
# take the conservative approach and decline delegation.
94-
input_shapes = [list(n.meta["val"].shape) for n in node.args[0]]
95-
if node.meta[NXP_NODE_FORMAT].is_channels_first():
96-
# Transform the shapes to channels last.
97-
to_nhwc_perm = create_channels_first_to_channels_last_permutation(
98-
len(node.meta["val"].shape), True
99-
)
100-
input_shapes = [
101-
apply_permutation_to(shape, to_nhwc_perm)
102-
for shape in input_shapes
103-
]
104-
105-
# Transform the `dim` to refer to a channels last dimension.
106-
dim = to_nhwc_perm.index(dim)
107-
108-
for input_shape in input_shapes:
109-
if not any(d != 1 for d in input_shape[:dim]):
110-
# Do not delegate if there are no "non-1" dimensions in the shape before the `dim` dimension.
111-
return False
112-
11389
# Neutron requires the channels to be a multiple of `8`. The channels could either be the second or the
11490
# last dimension, depending on the formats of the node.
11591
if node.meta[NXP_NODE_FORMAT].is_channels_first():
@@ -153,6 +129,46 @@ def _is_supported_in_IR(
153129

154130
return True
155131

132+
@classmethod
def supports_partitioning_result(
    cls,
    node: Node,
    partition_list: list[Partition],
    custom_delegation_options: CustomDelegationOptions,
):
    """Decide whether the `cat` `node` can stay delegated, given the partitions it ended up in.

    There is a bug in the NeutronConverter, where if none of the input dimensions before the one referenced by
     `dim` are `!= 1`, the `Concat` is not delegated.
    This only happens when the inputs to the `Concat` are model inputs, and not outputs of other operators.

    :param node: The `cat` node to check.
    :param partition_list: Proposed partitions. The first one containing `node` is used.
    :param custom_delegation_options: Not used by this check.
    :return: True if every (non-Q/DQ) producer of the `cat` inputs is a `call_function` node in the same
              partition as the `cat`, or if every input shape has a "non-1" dimension before `dim`.
              False otherwise.
    """
    partitions_with_cat = [
        partition for partition in partition_list if node in partition.nodes
    ]
    cat_partition = partitions_with_cat[0]

    # Look at the producers one by one, and stop at the first one which is not an operator in the partition.
    all_inputs_in_partition = True
    for cat_input in node.args[0]:
        producer = previous_non_qdq_node(cat_input)
        if producer.op != "call_function" or producer not in cat_partition.nodes:
            all_inputs_in_partition = False
            break

    if all_inputs_in_partition:
        return True

    # Some inputs of the `cat` are NOT in the same partition as `cat`.
    dim = CatConverter._get_normalized_dim(node)
    input_shapes = [list(n.meta["val"].shape) for n in node.args[0]]

    if node.meta[NXP_NODE_FORMAT].is_channels_first():
        # Transform the shapes to channels last.
        rank = len(node.meta["val"].shape)
        to_nhwc_perm = create_channels_first_to_channels_last_permutation(rank, True)
        input_shapes = [
            apply_permutation_to(shape, to_nhwc_perm) for shape in input_shapes
        ]

        # Transform the `dim` to refer to a channels last dimension.
        dim = to_nhwc_perm.index(dim)

    # Do not delegate if some input has no "non-1" dimensions in the shape before the `dim` dimension.
    return all(
        any(d != 1 for d in input_shape[:dim]) for input_shape in input_shapes
    )
171+
156172
def convert(self, node: Node):
157173
"""Convert the 'aten.cat' operator to TFLite 'Concatenation'."""
158174
self.assert_convertible(node)

backends/nxp/tests/ir/converter/node_converter/test_cat_converter.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,18 @@ def forward(self, *inputs: torch.Tensor):
4444
return torch.cat(list(inputs), self.dim)
4545

4646

47+
class AddCatModule(torch.nn.Module):
    """Test model which adds every input to itself and concatenates the results along `dim`."""

    def __init__(self, dim: int):
        super().__init__()
        self.dim = dim

    def forward(self, *inputs: torch.Tensor):
        # The `add` makes every input of the `cat` an output of another operator.
        doubled = []
        for tensor in inputs:
            doubled.append(tensor + tensor)

        return torch.cat(doubled, self.dim)
57+
58+
4759
class CatConvModule(torch.nn.Module):
4860

4961
def __init__(self, dim: int, channels: int = 4):
@@ -147,6 +159,9 @@ def test_cat__channels_first__same_shapes(dim, num_inputs, mocker):
147159
],
148160
)
149161
def test_cat__unsupported__imxrt700(dim, input_shape):
162+
"""This test is conjoined with the one below (`test_cat__context_dependent__imxrt700`).
163+
In this case, the inputs of the `cat` are NOT compute ops, so the `cat` is NOT delegated.
164+
"""
150165
num_inputs = 2
151166
quantized_program = to_quantized_edge_program(
152167
CatModule(dim), [input_shape] * num_inputs, target="imxrt700"
@@ -161,6 +176,32 @@ def test_cat__unsupported__imxrt700(dim, input_shape):
161176
)
162177

163178

179+
@pytest.mark.parametrize(
    "dim, input_shape",
    [
        pytest.param(0, (1, 8, 8, 8), id="axis = 0"),
        pytest.param(0, (8, 8, 8, 8), id="axis = 0, no `1s` in the shape."),
        pytest.param(-4, (1, 8, 8, 8), id="axis = -4"),
        pytest.param(1, (1, 1, 8, 8), id="axis = 1"),
        pytest.param(-3, (1, 1, 8, 8), id="axis = -3"),
        pytest.param(2, (1, 1, 1, 8), id="axis = 2"),
        pytest.param(-2, (1, 1, 1, 8), id="axis = -2"),
    ],
)
def test_cat__context_dependent__imxrt700(dim, input_shape):
    """Counterpart of `test_cat__unsupported__imxrt700`.
    Here the inputs of the `cat` are produced by compute ops (`add`), so the `cat` is expected to be delegated.
    """
    num_inputs = 2
    model = AddCatModule(dim)
    input_shapes = [input_shape] * num_inputs

    edge_program_manager = to_quantized_edge_program(
        model, input_shapes, target="imxrt700"
    )
    exported_program = edge_program_manager.exported_program()
    graph = exported_program.graph

    # Make sure the `Cat` was delegated: no `cat` is left in the graph, and a lowered module is present.
    assert not graph_contains_any_of_ops(graph, [exir_ops.edge.aten.cat.default])
    assert any("lowered_module" in graph_node.name for graph_node in graph.nodes)
203+
204+
164205
@pytest.mark.parametrize(
165206
"rank, num_inputs, dim",
166207
[

0 commit comments

Comments
 (0)