Skip to content

Commit 4e91b53

Browse files
committed
NXP backend: Add support for aten.clone with contiguous memory format.
This node is sometimes inserted into a QDQ cluster during lowering to the edge dialect, when a tensor has a memory format that the following node does not support.
1 parent 3374ff8 commit 4e91b53

File tree

4 files changed

+223
-50
lines changed

4 files changed

+223
-50
lines changed

backends/nxp/backend/ir/converter/node_converters/ops_converters/clone_converter.py

Lines changed: 25 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
# LICENSE file in the root directory of this source tree.
55

66
import torch
7+
78
from executorch.backends.nxp.backend.ir.converter.node_converter import (
89
CustomDelegationOptions,
910
NodeConverter,
@@ -13,10 +14,31 @@
1314

1415

1516
def _has_supported_memory_format(node: Node) -> bool:
    """Check whether a clone node requests a memory format which the converter can handle.

    The node can either represent an `aten.clone` operator (configured by the `memory_format` kwarg) or a
    `dim_order_ops._clone_dim_order` operator (configured by the `dim_order` kwarg).

    :param node: The clone node. Its `kwargs` are inspected, and `node.meta["val"].shape` is read to get the rank
                  when a `dim_order` is specified.
    :return: True if the clone preserves the memory format or converts to the contiguous format, False otherwise.
    """
    memory_format = node.kwargs.get("memory_format")  # Attribute of `aten.clone`.
    dim_order = node.kwargs.get(
        "dim_order"
    )  # Attribute of `dim_order_ops._clone_dim_order`.

    if dim_order is None:
        # No `memory_format` means the default (`torch.preserve_format`), in which case the operator does nothing
        #  (e.g. originated as a `Dropout`). The contiguous format is supported for the reason described below.
        return memory_format is None or memory_format in (
            torch.preserve_format,
            torch.contiguous_format,
        )

    if memory_format is not None:
        # Both attributes specified at once is not a recognized combination.
        return False

    # Sometimes there is a `permute_copy` (Transpose) in Executorch, which doesn't actually permute the data in
    #  memory. Instead, it just changes the `strides` (memory format) to match the permutation. Then, some
    #  following operator may or may not support the particular strides (e.g. `mul` supports anything but
    #  `view_copy` does not), so the `clone` may be inserted to actually permute the data in memory to the
    #  `contiguous` format. This is purely an Executorch issue, and there is no equivalent system in NeutronIR.
    #  In NeutronIR, every tensor is stored in memory exactly as its shape suggests. Therefore, the `clone` can
    #  simply be omitted.
    rank = len(node.meta["val"].shape)
    # Normalize to `list`, so that a dim order given as a tuple is recognized as well.
    return list(dim_order) == list(range(rank))
2042

2143

2244
class CloneConverter(NodeConverter):

backends/nxp/edge_passes/move_auxiliary_operator_into_separate_qdq_cluster_pass.py

Lines changed: 49 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,8 @@
2020
Relu = exir_ops.edge.aten.relu.default
2121
Sigmoid = exir_ops.edge.aten.sigmoid.default
2222
Tanh = exir_ops.edge.aten.tanh.default
23+
Clone = exir_ops.edge.aten.clone.default
24+
CloneDimOrder = exir_ops.edge.dim_order_ops._clone_dim_order.default
2325

2426

2527
def insert_qdq_pair_after_node(
@@ -69,29 +71,29 @@ def _is_quantize(node_: Node) -> bool:
6971

7072
class MoveLeadingAuxiliaryOperatorIntoSeparateQDQClusterPass(NeutronEdgePass):
7173
"""
72-
73-
┌─────▼──────┐
74-
│ │ dequantize │
75-
┌─────▼──────┐ └─────┬──────┘
76-
│ dequantize │ ┌─────▼──────┐
77-
└─────┬──────┘ │ <aux_node> │
78-
┌─────▼──────┐ └─────┬──────┘
79-
│ <aux_node> │ ┌────▼─────┐ ┐
80-
└─────┬──────┘ │ quantize │ │
81-
┌──────────▼──────────┐ replaced with └────┬─────┘ │
82-
┤ <main_cluster_node> ├ ──────────────► │ │ newly added nodes
83-
└──────────┬──────────┘ ┌─────▼──────┐ │
84-
▼ │ dequantize │ │
85-
└─────┬──────┘ ┘
86-
┌────▼─────┐ ┌──────────▼──────────┐
87-
│ quantize │ ┤ <main_cluster_node> ├
88-
└────┬─────┘ └──────────┬──────────┘
89-
▼ ▼
90-
91-
┌────▼─────┐
92-
│ quantize │
93-
└────┬─────┘
94-
74+
75+
┌─────▼──────┐
76+
│ │ dequantize │
77+
┌─────▼──────┐ └─────┬──────┘
78+
│ dequantize │ ┌─────▼──────┐
79+
└─────┬──────┘ │ <aux_node> │
80+
┌─────▼──────┐ └─────┬──────┘
81+
│ <aux_node> │ ┌────▼─────┐ ┐
82+
└─────┬──────┘ │ quantize │ │
83+
┌──────────▼──────────┐ replaced with └────┬─────┘ │
84+
...┤ <main_cluster_node> ├... ──────────────► │ │ newly added nodes
85+
└──────────┬──────────┘ ┌─────▼──────┐ │
86+
▼ │ dequantize │ │
87+
. └─────┬──────┘ ┘
88+
┌────▼─────┐ ┌──────────▼──────────┐
89+
│ quantize │ ...┤ <main_cluster_node> ├...
90+
└────┬─────┘ └──────────┬──────────┘
91+
▼ ▼
92+
.
93+
┌────▼─────┐
94+
│ quantize │
95+
└────┬─────┘
96+
9597
"""
9698

9799
# Dictionary mapping main cluster nodes to auxiliary nodes, for which this optimization will be applied.
@@ -102,6 +104,7 @@ class MoveLeadingAuxiliaryOperatorIntoSeparateQDQClusterPass(NeutronEdgePass):
102104
MM: [
103105
ViewCopy,
104106
],
107+
ViewCopy: [Clone, CloneDimOrder],
105108
}
106109

107110
def run(self, graph_module: torch.fx.GraphModule) -> PassResult:
@@ -152,28 +155,28 @@ def run(self, graph_module: torch.fx.GraphModule) -> PassResult:
152155

153156
class MoveTrailingAuxiliaryOperatorIntoSeparateQDQClusterPass(NeutronEdgePass):
154157
"""
155-
156-
┌─────▼──────┐
157-
│ │ dequantize │
158-
┌─────▼──────┐ └─────┬──────┘
159-
│ dequantize │
160-
└─────┬──────┘ ┌──────────▼──────────┐
161-
┤ <main_cluster_node> ├
162-
└──────────┬──────────┘
163-
┌──────────▼──────────┐ replaced with ┌────▼─────┐ ┐
164-
┤ <main_cluster_node> ├ ──────────────► │ quantize │ │
165-
└──────────┬──────────┘ └────┬─────┘ │
166-
┌─────▼──────┐ │ │ newly added nodes
167-
│ <aux_node> │ ┌─────▼──────┐ │
168-
└─────┬──────┘ │ dequantize │ │
169-
┌────▼─────┐ └─────┬──────┘ ┘
170-
│ quantize │ ┌─────▼──────┐
171-
└────┬─────┘ │ <aux_node> │
172-
▼ └─────┬──────┘
173-
┌────▼─────┐
174-
│ quantize │
175-
└────┬─────┘
176-
158+
159+
┌─────▼──────┐
160+
│ │ dequantize │
161+
┌─────▼──────┐ └─────┬──────┘
162+
│ dequantize │ .
163+
└─────┬──────┘ ┌──────────▼──────────┐
164+
...┤ <main_cluster_node> ├...
165+
. └──────────┬──────────┘
166+
┌──────────▼──────────┐ replaced with ┌────▼─────┐ ┐
167+
...┤ <main_cluster_node> ├... ──────────────► │ quantize │ │
168+
└──────────┬──────────┘ └────┬─────┘ │
169+
┌─────▼──────┐ │ │ newly added nodes
170+
│ <aux_node> │ ┌─────▼──────┐ │
171+
└─────┬──────┘ │ dequantize │ │
172+
┌────▼─────┐ └─────┬──────┘ ┘
173+
│ quantize │ ┌─────▼──────┐
174+
└────┬─────┘ │ <aux_node> │
175+
▼ └─────┬──────┘
176+
┌────▼─────┐
177+
│ quantize │
178+
└────┬─────┘
179+
177180
"""
178181

179182
# Dictionary mapping main cluster nodes to auxiliary nodes, for which this optimization will be applied.
@@ -198,6 +201,7 @@ class MoveTrailingAuxiliaryOperatorIntoSeparateQDQClusterPass(NeutronEdgePass):
198201
Sigmoid,
199202
Tanh,
200203
],
204+
ViewCopy: [Clone, CloneDimOrder],
201205
}
202206

203207
def run(self, graph_module: torch.fx.GraphModule) -> PassResult:

backends/nxp/neutron_partitioner.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,8 @@ class QDQCluster:
7979
exir_ops.edge.aten.relu.default,
8080
exir_ops.edge.aten.sigmoid.default,
8181
exir_ops.edge.aten.tanh.default,
82+
exir_ops.edge.aten.clone.default,
83+
exir_ops.edge.dim_order_ops._clone_dim_order.default,
8284
]
8385

8486
def __init__(self):

backends/nxp/tests/ir/converter/node_converter/test_clone_converter.py

Lines changed: 147 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,6 @@
22
#
33
# This source code is licensed under the BSD-style license found in the
44
# LICENSE file in the root directory of this source tree.
5-
6-
75
import itertools
86
import unittest
97

@@ -14,18 +12,41 @@
1412
from executorch.backends.nxp.backend.edge_program_converter import (
1513
EdgeProgramToIRConverter,
1614
)
15+
from executorch.backends.nxp.backend.ir.converter.node_converters.ops_converters import (
16+
PermuteCopyConverter,
17+
)
18+
from executorch.backends.nxp.backend.node_format_inference import NodeFormatInference
19+
from executorch.backends.nxp.edge_passes.move_auxiliary_operator_into_separate_qdq_cluster_pass import (
20+
MoveLeadingAuxiliaryOperatorIntoSeparateQDQClusterPass,
21+
)
22+
from executorch.backends.nxp.edge_passes.neutron_edge_pass_manager import (
23+
NeutronEdgePassManager,
24+
)
25+
from executorch.backends.nxp.edge_passes.remove_io_quant_ops_pass import (
26+
RemoveIOQuantOpsPass,
27+
)
28+
from executorch.backends.nxp.neutron_partitioner import QDQClusterRecognizer
29+
from executorch.backends.nxp.quantizer.neutron_quantizer import NeutronQuantizer
30+
from executorch.backends.nxp.quantizer.utils import post_training_quantize
31+
from executorch.backends.nxp.tests import executors
1732
from executorch.backends.nxp.tests.executorch_pipeline import (
33+
get_random_calibration_inputs,
34+
neutron_target_spec,
1835
to_edge_program,
36+
to_model_input_spec,
1937
to_quantized_edge_program,
2038
)
2139
from executorch.backends.nxp.tests.executors import (
2240
convert_run_compare,
2341
graph_contains_any,
2442
graph_contains_any_of_ops,
43+
OverrideTargetSupportCheck,
2544
ToChannelFirstPreprocess,
2645
ToChannelLastPreprocess,
2746
)
47+
from executorch.exir import EdgeCompileConfig
2848
from executorch.exir.dialects._ops import ops as exir_ops
49+
from executorch.extension.export_util.utils import export_to_edge
2950
from parameterized import parameterized
3051
from torch import nn
3152
from torch.export import ExportedProgram
@@ -76,6 +97,42 @@ def forward(self, x):
7697
return self.block(x)
7798

7899

100+
class TransposeReshapeModel(nn.Module):
    """Test model: doubles the input, moves its last dimension to index 1, and reshapes the result."""

    def __init__(self, new_shape: list[int]):
        super().__init__()
        # Target shape of the final `reshape`.
        self.new_shape = new_shape

    def forward(self, x):
        # The input is expected to be 4D (the permutation below has 4 entries).
        doubled = torch.add(x, x)
        permuted = torch.permute(doubled, [0, 3, 1, 2])
        # A `clone(memory_format=contiguous)` will be added here during the lowering to edge dialect.
        return torch.reshape(permuted, self.new_shape)
115+
116+
117+
class PermuteCopyReshapeModel(nn.Module):
    """Test model: doubles the input, permutes it, reshapes it, and doubles it again."""

    def __init__(self, new_shape: list[int], permutation: list[int]):
        super().__init__()
        # Permutation applied before the reshape, and the target shape of the reshape.
        self.permutation = permutation
        self.new_shape = new_shape

    def forward(self, x):
        # The input is expected to be 4D.
        doubled = torch.add(x, x)
        permuted = torch.permute(doubled, self.permutation)
        # A `clone(memory_format=contiguous)` will be added here during the lowering to edge dialect.
        reshaped = torch.reshape(permuted, self.new_shape)
        return torch.add(reshaped, reshaped)
134+
135+
79136
class TestCloneConverter(unittest.TestCase):
80137
__test__ = False # Prevent interfering with PyTest tests
81138

@@ -185,3 +242,91 @@ def test_clone_pool_view_copy_quant(self, input_shape: tuple[int] = (1, 64, 25,
185242
input_data=input_data,
186243
atol=1.0,
187244
)
245+
246+
def test_clone__to_contiguous_format(self):
    """Convert a quantized `add -> permute -> reshape` model where a contiguous clone precedes the `view_copy`.

    The clone is moved into its own QDQ cluster, the program is converted to the IR, and the outputs of the
    converted model are compared with the outputs of the edge program.
    """
    input_shape = (1, 8, 8, 8)
    new_shape = [1, 32, 2, 8]

    model = TransposeReshapeModel(new_shape).eval()

    calibration_inputs = get_random_calibration_inputs(
        to_model_input_spec(input_shape)
    )

    example_input = calibration_inputs[0]

    exir_program_aten = torch.export.export(model, example_input, strict=True)

    exir_program_aten__module_quant = post_training_quantize(
        exir_program_aten,
        calibration_inputs,
        NeutronQuantizer(neutron_target_spec),
    )

    edge_compile_config = EdgeCompileConfig(_check_ir_validity=False)
    edge_program_manager = export_to_edge(
        exir_program_aten__module_quant,
        example_input,
        edge_compile_config=edge_compile_config,
    )

    # Make sure the clone was inserted as expected. In the edge dialect it is represented by the
    #  `dim_order_ops._clone_dim_order` operator with the contiguous dim order.
    nodes = list(edge_program_manager.exported_program().graph.nodes)
    assert nodes[9].target == exir_ops.edge.dim_order_ops._clone_dim_order.default
    assert nodes[9].kwargs["dim_order"] == [0, 1, 2, 3]

    # Move the `clone` out of the cluster with the `view_copy`.
    edge_program_manager = edge_program_manager.transform(
        NeutronEdgePassManager(
            [MoveLeadingAuxiliaryOperatorIntoSeparateQDQClusterPass()]
        )
    )

    # Tag QDQ clusters, so the conversion works correctly.
    QDQClusterRecognizer().tag_qdq_clusters(
        list(edge_program_manager.exported_program().graph.nodes)
    )
    edge_program_manager.exported_program().graph_module.recompile()
    edge_program_manager = edge_program_manager.transform(
        [RemoveIOQuantOpsPass(edge_program_manager=edge_program_manager)]
    )

    # Identify the node formats.
    NodeFormatInference(
        edge_program_manager.exported_program()
    ).identify_node_formats()

    # Convert to the IR.
    converted_model, _ = EdgeProgramToIRConverter().convert_program(
        edge_program_manager.exported_program()
    )

    # Make sure the IR version produces the same outputs.
    # `np.random.random_integers` is deprecated, so `randint` is used instead. It draws directly from the whole
    #  `int8` range, which is what the original `[0, 255]` values wrapped to after the cast to `int8`.
    executors.convert_run_compare(
        edge_program_manager.exported_program(),
        np.random.randint(-128, 128, input_shape, dtype=np.int8),
        tfl_model=converted_model,
    )
310+
def test_clone__to_contiguous_format__non_delegated_permute_copy(self):
    """Check the partitioning when the `permute_copy` in front of the contiguous clone is not delegated."""
    input_shape = (2, 4, 6, 8)
    permutation = [3, 2, 1, 0]  # Unsupported by default.
    new_shape = [3, 4, 16, 2]

    model = PermuteCopyReshapeModel(new_shape, permutation).eval()

    # Prohibit `permute_copy` delegation in case support for the permutation is added in the future.
    reject_everything = lambda *_: False  # noqa: E731

    with OverrideTargetSupportCheck(
        PermuteCopyConverter, new_target_support_check=reject_everything
    ):
        quantized_program = to_quantized_edge_program(model, input_shape)
        ep = quantized_program.exported_program()

    graph_nodes = list(ep.graph.nodes)

    # No `aten.clone` may remain in the graph.
    clone_ops = [exir_ops.edge.aten.clone.default]
    assert not graph_contains_any_of_ops(ep.graph, clone_ops)

    # The `permute_copy` is expected between two delegate calls.
    assert graph_nodes[3].name == "executorch_call_delegate"
    assert graph_nodes[6].target == exir_ops.edge.aten.permute_copy.default
    assert graph_nodes[9].name == "executorch_call_delegate_1"

0 commit comments

Comments
 (0)