Skip to content

Commit b340ad1

Browse files
skywallStrycekSimon
authored andcommitted
NXP backend: Use per-channel quantization for Conv in NeutronQuantizer
1 parent 2c82054 commit b340ad1

File tree

9 files changed

+66
-50
lines changed

9 files changed

+66
-50
lines changed

backends/nxp/backend/ir/converter/node_converters/ops_converters/convolution_converter.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -321,6 +321,10 @@ def _convert_2d_conv(
321321
t_op.tmp_inputs[1] = self.builder.create_transposed_tensor(
322322
weight_tensor, perm
323323
)
324+
325+
if t_op.tmp_inputs[1].quantization is not None:
326+
# Model is quantized
327+
t_op.tmp_inputs[1].quantization.quantized_dimension = 3
324328
else:
325329
raise NotImplementedError("Dynamic Depthwise Conv weights.")
326330

backends/nxp/quantizer/patterns.py

Lines changed: 17 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
from executorch.backends.nxp.quantizer.utils import get_bias_qparams
1414
from torch import fx
1515
from torch._ops import OpOverload
16+
from torchao.quantization.pt2e import PerChannelMinMaxObserver
1617
from torchao.quantization.pt2e.quantizer import (
1718
DerivedQuantizationSpec,
1819
FixedQParamsQuantizationSpec,
@@ -318,30 +319,39 @@ def partition_types(self) -> list[OpOverload]:
318319
def get_anchors(
319320
self, gm: fx.GraphModule, fused_partition: list[fx.GraphModule]
320321
) -> PartitionAnchors:
321-
# pyre-fixme[29]: `Union[BoundMethod[typing.Callable(torch._C.TensorBase.__ge...
322322
conv2d_node = fused_partition[0].nodes[-1]
323323

324-
bias_qspec = DerivedQuantizationSpec(
324+
bias_quantization_qspec = DerivedQuantizationSpec(
325325
derived_from=[
326326
(conv2d_node.args[0], conv2d_node),
327327
(conv2d_node.args[1], conv2d_node),
328328
],
329329
derive_qparams_fn=get_bias_qparams,
330330
dtype=torch.int32,
331-
quant_min=-(2**31),
331+
quant_min=-(2**31) + 1,
332332
quant_max=2**31 - 1,
333-
qscheme=torch.per_tensor_affine,
333+
qscheme=torch.per_channel_symmetric,
334+
ch_axis=0,
335+
)
336+
337+
weight_observer_or_fake_quant_ctr = PerChannelMinMaxObserver
338+
weight_quantization_spec = QuantizationSpec(
339+
dtype=torch.int8,
340+
observer_or_fake_quant_ctr=weight_observer_or_fake_quant_ctr,
341+
quant_min=-127,
342+
quant_max=127,
343+
qscheme=torch.per_channel_symmetric,
344+
ch_axis=0,
334345
)
335346

336347
# Keep bias empty if not supplied
337348
bias = []
338349
if len(conv2d_node.args) > 2 and conv2d_node.args[2] is not None:
339-
bias = [(conv2d_node, NodeArgsIdx(2), bias_qspec)]
350+
bias = [(conv2d_node, NodeArgsIdx(2), bias_quantization_qspec)]
340351

341352
return PartitionAnchors(
342353
inputs=[(conv2d_node, NodeArgsIdx(0))],
343-
weights=[(conv2d_node, NodeArgsIdx(1))],
344-
# pyre-fixme[6]: Incompatible parameter type
354+
weights=[(conv2d_node, NodeArgsIdx(1), weight_quantization_spec)],
345355
biases=bias,
346356
output=[(conv2d_node,)],
347357
)

backends/nxp/quantizer/utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ def get_bias_qparams(
4949
act_scale, _ = obs_or_fqs[0].calculate_qparams()
5050
weight_scale, _ = obs_or_fqs[1].calculate_qparams()
5151
bias_scale = act_scale * weight_scale
52-
bias_zero_point = torch.zeros_like(bias_scale, dtype=torch.int32)
52+
bias_zero_point = torch.zeros_like(bias_scale, dtype=torch.int64)
5353
return bias_scale, bias_zero_point
5454

5555

backends/nxp/tests/ir/converter/node_converter/test_hardtanh_converter.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ def test_relu6_quant(mocker, input_shape: tuple[int], inplace: bool):
5757
tflite_input_preprocess=ToNHWCPreprocess(),
5858
tflite_output_preprocess=ToNCHWPreprocess(),
5959
input_data=input_data,
60-
atol=1.0,
60+
atol=2.0,
6161
)
6262

6363

backends/nxp/tests/ir/converter/node_converter/test_mean_dim_converter.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@ def test_mean_dim_conv_quant_conversion(mocker, input_shape, dim, keeepdim=True)
4949
input_data=input_data,
5050
tflite_output_preprocess=ToChannelFirstPreprocess(),
5151
tfl_model=tflite_flatbuffers_model,
52+
atol=1.0,
5253
)
5354

5455

backends/nxp/tests/ir/converter/node_converter/test_tanh_converter.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,7 @@ def test_conv_tanh(
7676
tflite_input_preprocess=ToChannelLastPreprocess(),
7777
tflite_output_preprocess=ToChannelFirstPreprocess(),
7878
input_data=input_data,
79-
atol=1.0,
79+
atol=2.0,
8080
)
8181

8282
@classmethod

backends/nxp/tests/test_batch_norm_fusion.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -168,7 +168,7 @@ def test_batch_norm_conv_fusing__full_pipeline__1d(bias: bool):
168168
nodes = list(edge_program.graph.nodes)
169169

170170
assert (
171-
len(nodes) == 13
171+
len(nodes) == 17
172172
) # 1D Conv currently isn't delegated, because it doesn't get quantized.
173173
assert not any(
174174
node.op == "call_function" and "batch_norm" in node.target.__name__

backends/nxp/tests/test_qdq_clustering_conv.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -16,13 +16,13 @@ def test_conv2d_partitioner():
1616
lowered_module = edge_program.exported_program().graph_module.lowered_module_0
1717
nodes = list(lowered_module.original_module.graph.nodes)
1818

19-
assert len(nodes) == 7
19+
assert len(nodes) == 9
2020

21-
q_x_node = nodes[1]
22-
dq_w_node = nodes[2]
23-
dq_x_node = nodes[3]
24-
conv_node = nodes[4]
25-
q_y_node = nodes[5]
21+
q_x_node = nodes[3]
22+
dq_w_node = nodes[4]
23+
dq_x_node = nodes[5]
24+
conv_node = nodes[6]
25+
q_y_node = nodes[7]
2626

2727
assert "cluster" not in q_x_node.meta
2828
assert dq_w_node.meta["cluster"] == "aten_convolution_default_cluster"

backends/nxp/tests/test_quantizer.py

Lines changed: 34 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright 2024 NXP
1+
# Copyright 2024-2025 NXP
22
#
33
# This source code is licensed under the BSD-style license found in the
44
# LICENSE file in the root directory of this source tree.
@@ -34,26 +34,26 @@ def test_quantizer_conv2d():
3434
m(*example_input)
3535

3636
nodes = list(m.graph.nodes)
37-
assert len(nodes) == 11
38-
assert nodes[7].name == "conv2d"
37+
assert len(nodes) == 15
38+
assert nodes[11].name == "conv2d"
3939
# [0]: Input, [1] : weights, [2]: bias
4040
assert (
41-
_get_target_name(nodes[7].args[0])
41+
_get_target_name(nodes[11].args[0])
4242
== "torch.ops.quantized_decomposed.dequantize_per_tensor.default"
4343
)
4444
assert (
45-
_get_target_name(nodes[7].args[1])
46-
== "torch.ops.quantized_decomposed.dequantize_per_tensor.default"
45+
_get_target_name(nodes[11].args[1])
46+
== "torch.ops.quantized_decomposed.dequantize_per_channel.default"
4747
)
4848
assert (
49-
_get_target_name(nodes[7].args[2])
50-
== "torch.ops.quantized_decomposed.dequantize_per_tensor.default"
49+
_get_target_name(nodes[11].args[2])
50+
== "torch.ops.quantized_decomposed.dequantize_per_channel.default"
5151
)
5252
assert (
53-
_get_target_name(nodes[8])
53+
_get_target_name(nodes[12])
5454
== "torch.ops.quantized_decomposed.quantize_per_tensor.default"
5555
)
56-
assert nodes[8].args[0].name == "conv2d"
56+
assert nodes[12].args[0].name == "conv2d"
5757

5858

5959
def test_quantizer_linear():
@@ -112,22 +112,22 @@ def test_quantizer_maxpool2d():
112112
m(*example_input)
113113

114114
nodes = list(m.graph.nodes)
115-
assert len(nodes) == 14
115+
assert len(nodes) == 18
116116
# Check if QDQ pattern:
117-
assert nodes[10].name == "max_pool2d"
117+
assert nodes[14].name == "max_pool2d"
118118
assert (
119-
_get_target_name(nodes[10].args[0])
119+
_get_target_name(nodes[14].args[0])
120120
== "torch.ops.quantized_decomposed.dequantize_per_tensor.default"
121121
)
122122
assert (
123-
_get_target_name(nodes[11])
123+
_get_target_name(nodes[15])
124124
== "torch.ops.quantized_decomposed.quantize_per_tensor.default"
125125
)
126-
assert nodes[11].args[0].name == "max_pool2d"
126+
assert nodes[15].args[0].name == "max_pool2d"
127127

128128
# Check if input and output quantization is same
129-
input_quant = nodes[10].args[0].args[1:]
130-
output_quant = nodes[11].args[1:]
129+
input_quant = nodes[14].args[0].args[1:]
130+
output_quant = nodes[15].args[1:]
131131
assert input_quant == output_quant
132132

133133

@@ -207,10 +207,10 @@ def test_quantizer_conv2d_relu():
207207
m(*example_input)
208208

209209
nodes = list(m.graph.nodes)
210-
assert len(nodes) == 12
211-
assert nodes[7].name == "dequantize_per_tensor_default_2"
212-
assert nodes[8].name == "relu"
213-
assert nodes[9].name == "quantize_per_tensor_default_3"
210+
assert len(nodes) == 14
211+
assert nodes[9].name == "dequantize_per_tensor_default_1"
212+
assert nodes[10].name == "relu"
213+
assert nodes[11].name == "quantize_per_tensor_default_2"
214214

215215

216216
def test_quantizer_conv2d_avg_pool2d():
@@ -230,10 +230,10 @@ def test_quantizer_conv2d_avg_pool2d():
230230
m(*example_input)
231231

232232
nodes = list(m.graph.nodes)
233-
assert len(nodes) == 14
234-
assert nodes[9].name == "dequantize_per_tensor_default_3"
235-
assert nodes[10].name == "avg_pool2d"
236-
assert nodes[11].name == "quantize_per_tensor_default_4"
233+
assert len(nodes) == 18
234+
assert nodes[13].name == "dequantize_per_tensor_default_1"
235+
assert nodes[14].name == "avg_pool2d"
236+
assert nodes[15].name == "quantize_per_tensor_default_2"
237237

238238

239239
def test_quantizer_conv2d_permute():
@@ -253,10 +253,11 @@ def test_quantizer_conv2d_permute():
253253
m(*example_input)
254254

255255
nodes = list(m.graph.nodes)
256-
assert len(nodes) == 12
257-
assert nodes[7].name == "dequantize_per_tensor_default_2"
258-
assert nodes[8].name == "permute"
259-
assert nodes[9].name == "quantize_per_tensor_default_3"
256+
257+
assert len(nodes) == 14
258+
assert nodes[9].name == "dequantize_per_tensor_default_1"
259+
assert nodes[10].name == "permute"
260+
assert nodes[11].name == "quantize_per_tensor_default_2"
260261

261262

262263
def test_multiple_shared_spec_ops_in_row():
@@ -281,15 +282,15 @@ def test_multiple_shared_spec_ops_in_row():
281282

282283
nodes = list(m.graph.nodes)
283284

284-
assert len(nodes) == 15
285-
assert nodes[-5].name == "dequantize_per_tensor_default_3"
285+
assert len(nodes) == 17
286+
assert nodes[-5].name.startswith("dequantize_per_tensor_default")
286287
assert nodes[-4].name == "max_pool2d"
287-
assert nodes[-3].name == "quantize_per_tensor_default_4"
288+
assert nodes[-3].name.startswith("quantize_per_tensor_default")
288289

289290
# Assert that post-ReLU quantize and pre-MaxPool dequantize has same specs
290291
assert nodes[-6].args[1:] == nodes[-5].args[1:]
291292
# Assert that post-Conv quantize and pre-ReLU dequantize has same specs
292-
assert nodes[6].args[1:] == nodes[7].args[1:]
293+
assert nodes[5].args[1:] == nodes[6].args[1:]
293294

294295

295296
def test_quantizers_order_invariance():

0 commit comments

Comments
 (0)