Commit 2c82054

skywall authored and StrycekSimon committed
NXP backend: Add support for per-channel quantization for Conv
1 parent e7c43b6 commit 2c82054

8 files changed: +211, -16 lines

backends/nxp/backend/edge_program_converter.py

Lines changed: 3 additions & 1 deletion
@@ -134,6 +134,7 @@ def _process_nodes(self, nodes: list[Node], conversion_context: ConversionContext
 
     qdq_related_functions = [
         exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default,
+        exir_ops.edge.quantized_decomposed.dequantize_per_channel.default,
         exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
     ]
 
@@ -203,7 +204,8 @@ def _convert_qdq_cluster_q_dq_nodes(
        :param conversion_context: ConversionContext instance.
        """
        qdq_q_ops_converters = {
-           exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default: QDQDequantizeConverter,  # noqa F405
+           exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default: QDQPerTensorDequantizeConverter,  # noqa F405
+           exir_ops.edge.quantized_decomposed.dequantize_per_channel.default: QDQPerChannelDequantizeConverter,  # noqa F405
            exir_ops.edge.quantized_decomposed.quantize_per_tensor.default: QDQQuantizeConverter,  # noqa F405
        }
 
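Note: the split into per-tensor and per-channel converters follows from the differing argument layouts of the two ops, as defined in PyTorch's quantized_decomposed op library. A sketch of the positional schemas (keyword-only extras such as out_dtype omitted):

# dequantize_per_tensor(input, scale, zero_point, quant_min, quant_max, dtype)
#     scale / zero_point are Python scalars -> node.args[1], node.args[2]
# dequantize_per_channel(input, scales, zero_points, axis, quant_min, quant_max, dtype)
#     scales / zero_points are tensors (one entry per channel) -> node.args[1], node.args[2]
#
# dtype is the last positional argument in both schemas, which is why the
# shared support check in qdq_dequantize_converter.py (next file) reads
# node.args[-1] rather than the per-tensor-only index node.args[5].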

backends/nxp/backend/ir/converter/node_converters/ops_converters/__init__.py

Lines changed: 4 additions & 2 deletions
@@ -41,7 +41,8 @@
     PermuteCopyConverter,
 )
 from executorch.backends.nxp.backend.ir.converter.node_converters.ops_converters.qdq_dequantize_converter import (
-    QDQDequantizeConverter,
+    QDQPerChannelDequantizeConverter,
+    QDQPerTensorDequantizeConverter,
 )
 from executorch.backends.nxp.backend.ir.converter.node_converters.ops_converters.qdq_quantize_converter import (
     QDQQuantizeConverter,
@@ -70,7 +71,8 @@
     "PermuteCopyConverter",
     "SoftmaxConverter",
     "ViewCopyConverter",
-    "QDQDequantizeConverter",
+    "QDQPerTensorDequantizeConverter",
+    "QDQPerChannelDequantizeConverter",
     "QDQQuantizeConverter",
     "ConstantPadNDConverter",
     "ReLUConverter",

backends/nxp/backend/ir/converter/node_converters/ops_converters/qdq_dequantize_converter.py

Lines changed: 32 additions & 6 deletions
@@ -2,6 +2,7 @@
 #
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
+from abc import ABC, abstractmethod
 
 import numpy as np
 
@@ -19,15 +20,23 @@
 from torch.nn import Parameter
 
 
-class QDQDequantizeConverter(NodeConverter):
+class QDQDequantizeConverterBase(NodeConverter, ABC):
+
+    @abstractmethod
+    def get_zero_point(self, node: Node) -> np.ndarray:
+        pass
+
+    @abstractmethod
+    def get_scale(self, node: Node) -> np.ndarray:
+        pass
 
     @staticmethod
     def _is_supported_in_IR(
         node: Node,
         parameters_mapping: dict[str, Parameter],
         custom_delegation_options: CustomDelegationOptions,
     ) -> bool:
-        zero_point_type = torch_type_to_numpy_type(node.args[5])
+        zero_point_type = torch_type_to_numpy_type(node.args[-1])
         if "cluster" not in node.meta or zero_point_type not in [np.int8, np.int32]:
             return False
 
@@ -39,10 +48,8 @@ def convert(self, node: Node):
         from_tensor = self.builder.tensor_for_name(node.name)
         to_tensor = self.builder.tensor_for_name(node.args[0].name)
 
-        zero_point_type = torch_type_to_numpy_type(node.args[5])
-
-        scale = np.array(node.args[1], dtype=np.float32)
-        zero_point = np.array(node.args[2], dtype=zero_point_type)
+        scale = self.get_scale(node)
+        zero_point = self.get_zero_point(node)
 
         if self.context.parameters_mapping.get(node.args[0].name, None) is None:
             # Convert dequantize as identity op (Transpose that will be removed) because
@@ -63,3 +70,22 @@ def convert(self, node: Node):
         # Change type so we pass check tensor similarity check when redirecting
         from_tensor.type = to_tensor.type
         self.builder.redirect_tensor(from_tensor, to_tensor)
+
+
+class QDQPerTensorDequantizeConverter(QDQDequantizeConverterBase):
+
+    def get_zero_point(self, node: Node) -> np.ndarray:
+        zero_point_type = torch_type_to_numpy_type(node.args[5])
+        return np.array(node.args[2], dtype=zero_point_type)
+
+    def get_scale(self, node: Node) -> np.ndarray:
+        return np.array(node.args[1], dtype=np.float32)
+
+
+class QDQPerChannelDequantizeConverter(QDQDequantizeConverterBase):
+
+    def get_zero_point(self, node: Node) -> np.ndarray:
+        return self.context.parameters_mapping[node.args[2].name].numpy()
+
+    def get_scale(self, node: Node) -> np.ndarray:
+        return self.context.parameters_mapping[node.args[1].name].numpy()
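Note: the two subclasses differ only in where the quantization parameters live; per-tensor nodes carry scalar scale/zero-point arguments, while per-channel nodes reference parameter tensors with one entry per channel. A rough numpy sketch of the dequantization semantics the converters map onto the IR (illustrative only, not the converter's actual code path):

import numpy as np

def dequantize_per_tensor(x_q, scale, zero_point):
    # One scalar (scale, zero_point) pair for the whole tensor.
    return (x_q.astype(np.float32) - zero_point) * scale

def dequantize_per_channel(x_q, scales, zero_points, axis):
    # One (scale, zero_point) pair per slice along `axis`; reshape the
    # parameter vectors so they broadcast across the remaining dims.
    shape = [1] * x_q.ndim
    shape[axis] = -1
    scales = scales.reshape(shape).astype(np.float32)
    zero_points = zero_points.reshape(shape).astype(np.float32)
    return (x_q.astype(np.float32) - zero_points) * scales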

backends/nxp/quantizer/patterns.py

Lines changed: 4 additions & 1 deletion
@@ -16,6 +16,7 @@
 from torchao.quantization.pt2e.quantizer import (
     DerivedQuantizationSpec,
     FixedQParamsQuantizationSpec,
+    QuantizationSpec,
     SharedQuantizationSpec,
 )
 from torchao.quantization.pt2e.quantizer.quantizer import Q_ANNOTATION_KEY
@@ -54,7 +55,9 @@ class PartitionAnchors:
         tuple[fx.Node, NodeArgsIdx]
         | tuple[fx.Node, NodeArgsIdx, SharedQuantizationSpec],
     ] = field(default_factory=list)
-    weights: list[tuple[fx.Node, NodeArgsIdx]] = field(default_factory=list)
+    weights: list[
+        tuple[fx.Node, NodeArgsIdx] | tuple[fx.Node, NodeArgsIdx, QuantizationSpec],
+    ] = field(default_factory=list)
     biases: list[
         tuple[fx.Node, NodeArgsIdx]
         | tuple[fx.Node, NodeArgsIdx, DerivedQuantizationSpec],
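Note: the widened `weights` field lets a pattern attach an explicit QuantizationSpec to a weight anchor via the new 3-tuple form. A hypothetical per-channel weight anchor, mirroring the test pattern added further down (`conv2d_node` is assumed to be the matched conv node):

per_channel_wgt_qspec = QuantizationSpec(
    dtype=torch.int8,
    observer_or_fake_quant_ctr=PerChannelMinMaxObserver,
    quant_min=-127,
    quant_max=127,
    qscheme=torch.per_channel_symmetric,
    ch_axis=0,  # one scale/zero-point per output channel
)
anchors = PartitionAnchors(
    inputs=[(conv2d_node, NodeArgsIdx(0))],
    weights=[(conv2d_node, NodeArgsIdx(1), per_channel_wgt_qspec)],  # 3-tuple form
)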

backends/nxp/tests/executorch_pipeline.py

Lines changed: 7 additions & 4 deletions
@@ -38,9 +38,9 @@ class ModelInputSpec:
     dtype: torch.dtype = torch.float32
 
 
-def _quantize_model(model, calibration_inputs: list[tuple[torch.Tensor, ...]]):
-    quantizer = NeutronQuantizer()
-
+def _quantize_model(
+    model, quantizer, calibration_inputs: list[tuple[torch.Tensor, ...]]
+):
     m = prepare_pt2e(model, quantizer)
     for data in calibration_inputs:
         m(*data)
@@ -91,6 +91,7 @@ def to_quantized_edge_program(
     neutron_converter_flavor="SDK_25_06",
     remove_quant_io_ops=False,
     custom_delegation_options=CustomDelegationOptions(),  # noqa B008
+    get_quantizer_fn=lambda: NeutronQuantizer(),
 ) -> EdgeProgramManager:
     calibration_inputs = get_calibration_inputs_fn(to_model_input_spec(input_spec))
 
@@ -102,7 +103,9 @@ def to_quantized_edge_program(
     exir_program_aten = torch.export.export(model, example_input, strict=True)
 
     exir_program_aten__module_quant = _quantize_model(
-        exir_program_aten.module(), calibration_inputs
+        exir_program_aten.module(),
+        get_quantizer_fn(),
+        calibration_inputs,
     )
 
     edge_compile_config = EdgeCompileConfig(_check_ir_validity=False)
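Note: with the quantizer now injected instead of hard-coded, callers can swap in a custom quantizer without touching the pipeline. A usage sketch (Conv2dPatternPerChannel and static_qconfig are defined in the new test file below; the default still builds a NeutronQuantizer):

# Default behaviour, unchanged:
edge_program = to_quantized_edge_program(model, input_shape)

# Supplying a custom quantizer, e.g. for per-channel Conv weights:
edge_program = to_quantized_edge_program(
    model,
    input_shape,
    get_quantizer_fn=lambda: NeutronAtenQuantizer(
        Conv2dPatternPerChannel(is_per_channel=True), static_qconfig
    ),
)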
Lines changed: 153 additions & 0 deletions
@@ -0,0 +1,153 @@
+# Copyright 2025 NXP
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import unittest
+
+import kgb
+import numpy as np
+import torch
+
+from executorch.backends.nxp.backend.edge_program_converter import (
+    EdgeProgramToIRConverter,
+)
+from executorch.backends.nxp.quantizer.neutron_quantizer import (
+    act_qspec,
+    NeutronAtenQuantizer,
+    wgt_qspec,
+)
+from executorch.backends.nxp.quantizer.patterns import (
+    NodeArgsIdx,
+    PartitionAnchors,
+    QuantizationPattern,
+)
+from executorch.backends.nxp.quantizer.utils import get_bias_qparams
+from executorch.backends.nxp.tests.executorch_pipeline import to_quantized_edge_program
+from executorch.backends.nxp.tests.executors import (
+    convert_run_compare,
+    ToChannelFirstPreprocess,
+    ToChannelLastPreprocess,
+)
+from executorch.backends.nxp.tests.models import Conv2dModule
+from executorch.backends.nxp.tests.test_quantizer import _get_target_name
+
+from torch import fx
+from torch._ops import OpOverload
+from torch.export import ExportedProgram
+from torchao.quantization.pt2e import MinMaxObserver, PerChannelMinMaxObserver
+from torchao.quantization.pt2e.quantizer import (
+    DerivedQuantizationSpec,
+    QuantizationConfig,
+    QuantizationSpec,
+)
+
+
+class Conv2dPatternPerChannel(QuantizationPattern):
+
+    def __init__(self, is_per_channel: bool):
+        super().__init__()
+        self.is_per_channel = is_per_channel
+
+    def partition_types(self) -> list[OpOverload]:
+        return [torch.ops.aten.conv2d.default]
+
+    def get_anchors(
+        self, gm: fx.GraphModule, fused_partition: list[fx.GraphModule]
+    ) -> PartitionAnchors:
+        conv2d_node = fused_partition[0].nodes[-1]
+
+        bias_qscheme = (
+            torch.per_channel_symmetric
+            if self.is_per_channel
+            else torch.per_tensor_symmetric
+        )
+        bias_quantization_qspec = DerivedQuantizationSpec(
+            derived_from=[
+                (conv2d_node.args[0], conv2d_node),
+                (conv2d_node.args[1], conv2d_node),
+            ],
+            derive_qparams_fn=get_bias_qparams,
+            dtype=torch.int32,
+            quant_min=-(2**31) + 1,
+            quant_max=2**31 - 1,
+            qscheme=bias_qscheme,
+            ch_axis=0,
+        )
+
+        weight_qscheme = (
+            torch.per_channel_symmetric
+            if self.is_per_channel
+            else torch.per_tensor_symmetric
+        )
+        weight_observer_or_fake_quant_ctr = (
+            PerChannelMinMaxObserver if self.is_per_channel else MinMaxObserver
+        )
+        weight_quantization_spec = QuantizationSpec(
+            dtype=torch.int8,
+            observer_or_fake_quant_ctr=weight_observer_or_fake_quant_ctr,
+            quant_min=-127,
+            quant_max=127,
+            qscheme=weight_qscheme,
+            ch_axis=0,
+        )
+
+        return PartitionAnchors(
+            inputs=[(conv2d_node, NodeArgsIdx(0))],
+            weights=[(conv2d_node, NodeArgsIdx(1), weight_quantization_spec)],
+            biases=[(conv2d_node, NodeArgsIdx(2), bias_quantization_qspec)],
+            output=[(conv2d_node,)],
+        )
+
+
+class TestPerChannelConversion(unittest.TestCase):
+    __test__ = False  # Prevent interfering with PyTest tests
+
+    def test_per_channel_convolution(self):
+        with kgb.spy_on(
+            EdgeProgramToIRConverter.convert_program, call_original=True
+        ) as converter_spy:
+            model = Conv2dModule(
+                in_channels=8, out_channels=32, kernel_size=5, padding=3
+            )
+            input_shape = (1, 8, 32, 32)
+
+            static_qconfig = QuantizationConfig(act_qspec, act_qspec, wgt_qspec, None)
+            _ = to_quantized_edge_program(
+                model,
+                input_shape,
+                get_quantizer_fn=lambda: NeutronAtenQuantizer(
+                    Conv2dPatternPerChannel(is_per_channel=True), static_qconfig
+                ),
+            )
+
+            tflite_flatbuffers_model, io_formats = converter_spy.calls[-1].return_value
+            exported_program: ExportedProgram = converter_spy.calls[-1].args[0]
+
+            input_data = (np.random.random(input_shape).astype(np.float32) * 50).astype(
+                np.int8
+            )
+
+            convert_run_compare(
+                exported_program,
+                tflite_input_preprocess=ToChannelLastPreprocess(),
+                tfl_model=tflite_flatbuffers_model,
+                tflite_output_preprocess=ToChannelFirstPreprocess(),
+                input_data=input_data,
+                atol=1.0,
+            )
+
+            nodes = list(exported_program.graph.nodes)
+
+            assert _get_target_name(nodes[8]).endswith(
+                "quantized_decomposed.dequantize_per_channel.default"
+            )
+            assert _get_target_name(nodes[9]).endswith(
+                "quantized_decomposed.dequantize_per_channel.default"
+            )
+            assert nodes[10].name == "aten_convolution_default"
+
+    @classmethod
+    def setUpClass(cls):
+        torch.manual_seed(25)
+        np.random.seed(25)
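Note: the test's weight spec uses a symmetric int8 range of [-127, 127], which pins every per-channel zero point to 0. A standalone sketch of what per-channel symmetric quantization of a conv weight computes (hypothetical values, not taken from the test):

import torch

# Conv2d weight laid out as (out_channels, in_channels, kH, kW); ch_axis=0
# gives each output channel its own scale.
w = torch.randn(32, 8, 5, 5)
scales = w.abs().amax(dim=(1, 2, 3)) / 127.0     # one scale per output channel
zero_points = torch.zeros(32, dtype=torch.int64)  # symmetric -> zero point 0
w_q = torch.quantize_per_channel(w, scales, zero_points, axis=0, dtype=torch.qint8)
print(w_q.q_per_channel_scales().shape)  # torch.Size([32])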

backends/nxp/tests/test_removing_dead_code.py

Lines changed: 3 additions & 1 deletion
@@ -9,6 +9,7 @@
 import pytest
 import torch
 
+from executorch.backends.nxp.quantizer.neutron_quantizer import NeutronQuantizer
 from executorch.backends.nxp.tests.executorch_pipeline import _quantize_model
 from executorch.backends.nxp.tests.executors import graph_contains_any_of_ops
 
@@ -45,8 +46,9 @@ def test_removing_dead_code(self):
         )
 
         # The `NeutronQuantizer` should remove the dead code in the `transform_for_annotation()` method.
+        quantizer = NeutronQuantizer()
         exir_program_aten_quant = _quantize_model(
-            exir_program_aten.module(), [example_inputs]
+            exir_program_aten.module(), quantizer, [example_inputs]
         )
 
         # Make sure the is no `add` operation in the graph anymore.

backends/nxp/tests/test_split_group_convolution.py

Lines changed: 5 additions & 1 deletion
@@ -17,6 +17,7 @@
 )
 from executorch.backends.nxp.neutron_partitioner import NeutronPartitioner
 from executorch.backends.nxp.nxp_backend import generate_neutron_compile_spec
+from executorch.backends.nxp.quantizer.neutron_quantizer import NeutronQuantizer
 from executorch.backends.nxp.tests.executorch_pipeline import (
     _quantize_model,
     get_random_calibration_inputs,
@@ -39,8 +40,11 @@ def _quantize_and_lower_module(
     module: GraphModule, input_shape: tuple[int, ...], target="imxrt700"
 ) -> EdgeProgramManager:
     calibration_inputs = get_random_calibration_inputs(to_model_input_spec(input_shape))
+    quantizer = NeutronQuantizer()
 
-    exir_program_aten__module_quant = _quantize_model(module, calibration_inputs)
+    exir_program_aten__module_quant = _quantize_model(
+        module, quantizer, calibration_inputs
+    )
 
     edge_compile_config = EdgeCompileConfig(_check_ir_validity=False)
     edge_program_manager = export_to_edge(
