Skip to content

Commit aa651f1

Browse files
committed
NXP backend: Add RemoveAdditionalQDQClustersPass.
1 parent 0413e02 commit aa651f1

File tree

7 files changed

+282
-18
lines changed

7 files changed

+282
-18
lines changed

backends/nxp/backend/edge_helper.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,3 +87,14 @@ def try_get_tensor_constant_from_node(
8787
return None
8888
attr_itr = getattr(attr_itr, atom)
8989
return attr_itr
90+
91+
92+
Scale = list[float] | float
93+
ZeroPoint = list[int] | int
94+
95+
96+
def get_quantization_parameters_for(node: Node) -> tuple[Scale, ZeroPoint] | None:
97+
if "quantize" not in node.target.__name__ or len(node.args) < 3:
98+
return None
99+
100+
return node.args[1], node.args[2] # Scale and zero_point
Lines changed: 111 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,111 @@
1+
# Copyright 2025 NXP
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
import numpy as np
7+
import torch
8+
9+
from executorch.backends.nxp.backend.edge_helper import get_quantization_parameters_for
10+
from executorch.backends.nxp.edge_passes.neutron_edge_pass import NeutronEdgePass
11+
from executorch.backends.nxp.neutron_partitioner import QDQClusterRecognizer
12+
from executorch.exir.dialects._ops import ops as exir_ops
13+
from torch.fx.passes.infra.pass_base import PassResult
14+
15+
16+
class RemoveAdditionalQDQClustersPass(NeutronEdgePass):
    """
    After delegation of partitions, there may be additional dequantize quantize nodes for QDQ clusters that were
    not delegated. If dequantize quantize nodes are quantized per tensor and quantization parameters of dequantize
    and quantize nodes in a QDQ cluster are equal, the nodes can be removed and thus the inner nodes computed in int8.


        ┌────────────▼──────────┐
        │ dequantize_per_tensor │
        └────────────┬──────────┘
                     │                                      │
                 ┌───▼──┐          replace with         ┌───▼──┐
                 │ node │        ──────────────►        │ node │
                 └───┬──┘                               └───┬──┘
                     │                                      ▼
         ┌───────────▼─────────┐
         │ quantize_per_tensor │
         └───────────┬─────────┘


    """

    # Per-channel (de)quantize targets. A cluster containing any of these is never removed.
    qdq_per_channel_nodes = (
        exir_ops.edge.quantized_decomposed.dequantize_per_channel.default,
        exir_ops.edge.quantized_decomposed.quantize_per_channel.default,
    )

    # Per-tensor (de)quantize targets. These are the nodes the pass removes.
    qdq_per_tensor_nodes = (
        exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
        exir_ops.edge.quantized_decomposed.quantize_per_tensor.tensor,
        exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default,
        exir_ops.edge.quantized_decomposed.dequantize_per_tensor.tensor,
    )

    @staticmethod
    def _is_static_value(value) -> bool:
        """Return True if `value` is a plain number or a list/tuple of plain numbers."""
        if isinstance(value, (list, tuple)):
            return all(isinstance(item, (int, float)) for item in value)
        return isinstance(value, (int, float))

    @classmethod
    def _has_static_quant_params(cls, quant_params) -> bool:
        """Return True if `quant_params` is a `(scale, zero_point)` pair made of plain numbers."""
        if quant_params is None:
            return False
        scale, zero_point = quant_params
        return cls._is_static_value(scale) and cls._is_static_value(zero_point)

    @staticmethod
    def _adjacent_values_close(values) -> bool:
        """Return True if every two neighbouring items of `values` are equal according to `np.allclose`."""
        return all(
            np.allclose(current, following)
            for current, following in zip(values, values[1:])
        )

    def run(self, graph_module: torch.fx.GraphModule) -> PassResult:
        """Remove the per-tensor Q/DQ nodes of the first removable QDQ cluster.

        :param graph_module: Graph module to process.
        :return: `PassResult` with the recompiled module and `True` if a cluster was removed, otherwise the
                  unchanged module and `False`.
        """
        qdq_clusterer = QDQClusterRecognizer()
        qdq_clusterer.tag_qdq_clusters(list(graph_module.graph.nodes))

        for cluster in qdq_clusterer.cluster_map.values():
            # For now, enable only permute_copy and cat.
            if cluster.compute_node.target not in [
                exir_ops.edge.aten.permute_copy.default,
                exir_ops.edge.aten.cat.default,
            ]:
                continue

            # Ensure cluster doesn't contain dequantize/quantize per channel nodes.
            if any(node.target in self.qdq_per_channel_nodes for node in cluster.ops):
                continue

            qdq_nodes = [
                node for node in cluster.ops if node.target in self.qdq_per_tensor_nodes
            ]

            # Nothing to remove. Reporting a modification without changing the graph would make the parent class
            # re-run the pass on an identical graph.
            if not qdq_nodes:
                continue

            qdq_nodes_quant_params = [
                get_quantization_parameters_for(node) for node in qdq_nodes
            ]

            # The parameters can only be compared if they are known constants. Presumably the `.tensor` overloads
            # carry the scale and zero point as graph nodes instead of numbers - such clusters are left untouched.
            if not all(
                self._has_static_quant_params(quant_params)
                for quant_params in qdq_nodes_quant_params
            ):
                continue

            scales = [scale for scale, _ in qdq_nodes_quant_params]
            zero_points = [zero_point for _, zero_point in qdq_nodes_quant_params]

            # Check if all quantization params are equal to ensure that QDQ cluster can be removed.
            if not (
                self._adjacent_values_close(scales)
                and self._adjacent_values_close(zero_points)
            ):
                continue

            # Replace the uses of each dequantize/quantize node with its arg node.
            for qdq_node in qdq_nodes:
                qdq_node.replace_all_uses_with(qdq_node.args[0])
                graph_module.graph.erase_node(qdq_node)

            # Remove compute node cluster info from node meta. Tolerate an already missing tag.
            cluster.compute_node.meta.pop("cluster", None)

            graph_module = self.recompile_module(graph_module)

            # The graph has now changed, and we cannot keep iterating through it. Return the new graph and the parent
            # class will call this pass again.
            return PassResult(graph_module, True)

        return PassResult(graph_module, False)

backends/nxp/tests/executorch_pipeline.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,9 @@
1515
from executorch.backends.nxp.edge_passes.neutron_edge_pass_manager import (
1616
NeutronEdgePassManager,
1717
)
18+
from executorch.backends.nxp.edge_passes.remove_additional_quantize_dequantize_nodes_pass import (
19+
RemoveAdditionalQDQClustersPass,
20+
)
1821
from executorch.backends.nxp.neutron_partitioner import NeutronPartitioner
1922
from executorch.backends.nxp.nxp_backend import generate_neutron_compile_spec
2023
from executorch.backends.nxp.quantizer.neutron_quantizer import NeutronQuantizer
@@ -58,7 +61,6 @@ def get_random_calibration_inputs(
5861
def to_model_input_spec(
5962
input_spec: tuple[ModelInputSpec, ...] | tuple[int, ...] | list[tuple[int, ...]]
6063
) -> tuple[ModelInputSpec, ...]:
61-
6264
if isinstance(input_spec, tuple) and all(
6365
isinstance(spec, ModelInputSpec) for spec in input_spec
6466
):
@@ -126,6 +128,10 @@ def to_quantized_edge_program(
126128
partitioner = NeutronPartitioner(compile_spec, custom_delegation_options)
127129
edge_program_manager = edge_program_manager.to_backend(partitioner)
128130

131+
edge_program_manager = NeutronEdgePassManager([RemoveAdditionalQDQClustersPass()])(
132+
edge_program_manager
133+
)
134+
129135
return edge_program_manager
130136

131137

backends/nxp/tests/ir/converter/node_converter/test_permute_copy_converter.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,7 @@ def forward(self, x):
104104
return torch.permute(x, self.perm)
105105

106106

107-
class TestPermuteCopyConversion(kgb.SpyAgency, unittest.TestCase):
107+
class TestPermuteCopyConversion(unittest.TestCase):
108108
@classmethod
109109
def setUpClass(cls):
110110
torch.manual_seed(23)
@@ -302,9 +302,9 @@ def test_permute_copy_non_delegated_conversion__from_permute_4D__quantized(
302302
edge_program = to_quantized_edge_program(model, input_shape).exported_program()
303303

304304
nodes = list(edge_program.graph.nodes)
305-
assert len(nodes) == 10
305+
assert len(nodes) == 8
306306
assert (
307-
nodes[6].target == exir_ops.edge.aten.permute_copy.default
307+
nodes[5].target == exir_ops.edge.aten.permute_copy.default
308308
) # PermuteCopy not delegated.
309309

310310
@parameterized.expand(
@@ -320,7 +320,7 @@ def test_permute_copy_non_delegated_conversion__from_transpose_4D__quantized(
320320
edge_program = to_quantized_edge_program(model, input_shape).exported_program()
321321

322322
nodes = list(edge_program.graph.nodes)
323-
assert len(nodes) == 10
323+
assert len(nodes) == 8
324324
assert (
325-
nodes[6].target == exir_ops.edge.aten.permute_copy.default
325+
nodes[5].target == exir_ops.edge.aten.permute_copy.default
326326
) # PermuteCopy not delegated.

backends/nxp/tests/test_edge_passes.py

Lines changed: 141 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,49 @@
1+
# Copyright 2025 NXP
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
import copy
7+
import unittest
8+
19
import numpy as np
10+
import torch
11+
12+
from executorch.backends.nxp.aten_passes.neutron_aten_pass_manager import (
13+
NeutronAtenPassManager,
14+
)
15+
from executorch.backends.nxp.backend.custom_delegation_options import (
16+
CustomDelegationOptions,
17+
)
218
from executorch.backends.nxp.backend.ir.converter.node_converters.ops_converters import (
319
ViewCopyConverter,
420
)
5-
from executorch.backends.nxp.tests.executorch_pipeline import to_quantized_edge_program
21+
from executorch.backends.nxp.edge_passes.neutron_edge_pass_manager import (
22+
NeutronEdgePassManager,
23+
)
24+
from executorch.backends.nxp.edge_passes.remove_additional_quantize_dequantize_nodes_pass import (
25+
RemoveAdditionalQDQClustersPass,
26+
)
27+
from executorch.backends.nxp.neutron_partitioner import NeutronPartitioner
28+
from executorch.backends.nxp.nxp_backend import generate_neutron_compile_spec
29+
from executorch.backends.nxp.quantizer.neutron_quantizer import NeutronQuantizer
30+
from executorch.backends.nxp.tests.executorch_pipeline import (
31+
_quantize_model,
32+
get_random_calibration_inputs,
33+
to_model_input_spec,
34+
to_quantized_edge_program,
35+
)
636
from executorch.backends.nxp.tests.executors import (
37+
compare_output_arrays,
738
EdgeProgramExecutor,
839
OverrideTargetSupportCheck,
940
)
41+
from executorch.backends.nxp.tests.ir.converter.node_converter.test_permute_copy_converter import (
42+
Conv2dPermuteModule,
43+
)
1044
from executorch.backends.nxp.tests.models import ConvFCFCSoftmaxModuleWithoutReshape
1145
from executorch.exir.dialects._ops import ops as exir_ops
46+
from executorch.extension.export_util.utils import export_to_edge
1247
from torch.fx import Graph, Node
1348

1449

@@ -57,18 +92,26 @@ def _assert_nodes_form_a_view_copy_qdq_cluster(graph: Graph, node_indices: list[
5792
assert quantize.args[0] == view_copy
5893

5994

60-
def test_moving_view_copy_into_separate_qdq_clusters():
61-
model = ConvFCFCSoftmaxModuleWithoutReshape()
62-
input_shape = (1, 4, 3, 33)
95+
class TestEdgePasses(unittest.TestCase):
96+
@classmethod
97+
def setUpClass(cls):
98+
torch.manual_seed(23)
99+
np.random.seed(42)
63100

64-
# Prohibit `view_copy` conversion for the testing purposes.
65-
def unsupported_target(*_):
66-
return False
101+
def test_moving_view_copy_into_separate_qdq_clusters(self):
102+
model = ConvFCFCSoftmaxModuleWithoutReshape()
103+
input_shape = (1, 4, 3, 33)
104+
105+
# Prohibit `view_copy` conversion for the testing purposes.
106+
def unsupported_target(*_):
107+
return False
108+
109+
# Prohibit `view_copy` conversion for the testing purposes.
110+
with OverrideTargetSupportCheck(
111+
ViewCopyConverter, new_target_support_check=unsupported_target
112+
):
113+
epm = to_quantized_edge_program(model, input_shape, target="imxrt700")
67114

68-
with OverrideTargetSupportCheck(
69-
ViewCopyConverter, new_target_support_check=unsupported_target
70-
):
71-
epm = to_quantized_edge_program(model, input_shape, target="imxrt700")
72115
exported_program = epm.exported_program()
73116

74117
nodes = list(exported_program.graph_module.graph.nodes)
@@ -86,3 +129,90 @@ def unsupported_target(*_):
86129
input_data = np.random.random(input_shape).astype("float32")
87130
program_executor = EdgeProgramExecutor(exported_program)
88131
program_executor.inference(input_data)
132+
133+
def test_remove_additional_quantize_dequantize_nodes_pass(self):
    """Check that `RemoveAdditionalQDQClustersPass` removes the QDQ cluster around a non-delegated `permute_copy`.

    The graph is inspected before and after the pass, and the outputs of both edge programs are compared on
    the same random input.
    """
    dequantize_op = exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default
    quantize_op = exir_ops.edge.quantized_decomposed.quantize_per_tensor.default
    permute_op = exir_ops.edge.aten.permute_copy.default

    input_shape = (1, 3, 8, 16)
    model = Conv2dPermuteModule(input_shape[1], (3, 2, 1, 0))
    delegation_options = CustomDelegationOptions()

    calibration_inputs = get_random_calibration_inputs(
        to_model_input_spec(input_shape)
    )
    example_input = calibration_inputs[0]

    # Export to the float32 aten dialect and run its pre-processing passes.
    aten_module = torch.export.export(model, example_input).module()
    aten_module = NeutronAtenPassManager()(aten_module).graph_module

    # Quantize and lower to the edge dialect.
    quantized_module = _quantize_model(
        aten_module, NeutronQuantizer(), calibration_inputs
    )
    epm = export_to_edge(quantized_module, example_input)
    epm = NeutronEdgePassManager()(epm)

    # Delegate to the Neutron backend.
    partitioner = NeutronPartitioner(
        generate_neutron_compile_spec("imxrt700", "SDK_25_09"), delegation_options
    )
    epm = epm.to_backend(partitioner)

    # Make sure QDQ cluster for permute_copy is present. Keep a copy of the program for the later comparison.
    program_before = copy.deepcopy(epm.exported_program())
    nodes_before = list(program_before.graph.nodes)
    assert len(nodes_before) == 10
    assert nodes_before[5].target == dequantize_op
    assert nodes_before[6].target == permute_op
    assert "cluster" in nodes_before[6].meta
    assert nodes_before[7].target == quantize_op

    # Run pass for removal of additional QDQ nodes and compute in non-float types where possible
    epm = NeutronEdgePassManager([RemoveAdditionalQDQClustersPass()])(epm)

    # Make sure QDQ cluster for permute_copy is removed.
    program_after = epm.exported_program()
    nodes_after = list(program_after.graph.nodes)
    assert len(nodes_after) == 8
    assert nodes_after[4].name == "getitem"
    assert nodes_after[5].target == permute_op
    assert "cluster" not in nodes_after[5].meta
    assert nodes_after[6].target == dequantize_op

    executor_after = EdgeProgramExecutor(program_after)
    executor_before = EdgeProgramExecutor(program_before)

    # Both programs must produce matching outputs for the same input.
    input_data = np.random.random(input_shape).astype(np.float32)
    output_after = executor_after.inference(input_data)
    output_before = executor_before.inference(input_data)

    compare_output_arrays(output_after, output_before, "main output")

backends/nxp/tests/test_turning_batch_first_gru_to_time_major.py

Whitespace-only changes.

examples/nxp/aot_neutron_compile.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,13 +11,15 @@
1111
from collections import defaultdict
1212
from typing import Iterator
1313

14-
import executorch.extension.pybindings.portable_lib
1514
import executorch.kernels.quantized # noqa F401
1615

1716
import torch
1817
from executorch.backends.nxp.edge_passes.neutron_edge_pass_manager import (
1918
NeutronEdgePassManager,
2019
)
20+
from executorch.backends.nxp.edge_passes.remove_additional_quantize_dequantize_nodes_pass import (
21+
RemoveAdditionalQDQClustersPass,
22+
)
2123
from executorch.backends.nxp.neutron_partitioner import NeutronPartitioner
2224
from executorch.backends.nxp.nxp_backend import generate_neutron_compile_spec
2325
from executorch.backends.nxp.quantizer.neutron_quantizer import NeutronQuantizer
@@ -272,6 +274,10 @@ def _get_batch_size(data):
272274
remove_io_quant_ops=args.remove_quant_io_ops
273275
)(edge_program_manager)
274276

277+
edge_program_manager = NeutronEdgePassManager([RemoveAdditionalQDQClustersPass()])(
278+
edge_program_manager
279+
)
280+
275281
logging.debug(f"Lowered graph:\n{edge_program_manager.exported_program().graph}")
276282

277283
# 5. Export to ExecuTorch program

0 commit comments

Comments
 (0)