
Commit ffd27ca

NXP backend: Extend tests for Linear/Addmm/Mm converters, add Mm quantization (#14601)
### Summary

This PR refactors and extends the tests for the Addmm and Mm converters, and adds quantization support for the Mm operator.

### Test plan

Unit tests provided.

cc @digantdesai @JakeStevens @robert-kalmar
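For context, the pattern that becomes quantizable here is a bare matrix multiplication with no bias. A minimal sketch of a module that exercises it (the module name and shapes are illustrative, not taken from the repository):

```python
import torch


class MmModel(torch.nn.Module):
    """Hypothetical example: a bare matmul that traces to aten.mm.default."""

    def __init__(self, in_features: int = 32, out_features: int = 16):
        super().__init__()
        self.weight = torch.nn.Parameter(torch.randn(in_features, out_features))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # torch.mm has no bias input, so the new MmPattern anchors only
        # an input, a weight, and an output -- no bias anchor.
        return torch.mm(x, self.weight)
```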
1 parent b265324 commit ffd27ca

File tree

6 files changed: +260 -112 lines changed


backends/nxp/quantizer/neutron_quantizer.py

Lines changed: 2 additions & 0 deletions
```diff
@@ -25,6 +25,7 @@
     LinearPattern,
     MaxPoolPattern,
     MeanDimPattern,
+    MmPattern,
     NodeArgsIdx,
     PadPattern,
     PermutePattern,
@@ -199,6 +200,7 @@ def __init__(self):
             NeutronAtenQuantizer(LinearPattern(), static_fc_qconfig),
             NeutronAtenQuantizer(MaxPoolPattern(), static_qconfig),
             NeutronAtenQuantizer(MeanDimPattern(), static_qconfig),
+            NeutronAtenQuantizer(MmPattern(), static_qconfig),
             NeutronAtenQuantizer(PadPattern(), static_qconfig),
             NeutronAtenQuantizer(PermutePattern(), static_qconfig),
             NeutronAtenQuantizer(ReluPattern(), static_qconfig),
```
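With MmPattern registered, the quantizer can annotate mm nodes through the standard PT2E flow. A minimal sketch, assuming the generic prepare/convert entry points (the NXP test helper to_quantized_edge_program wraps a similar sequence; exact signatures may differ by version):

```python
import torch
from executorch.backends.nxp.quantizer.neutron_quantizer import NeutronQuantizer
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e

model = MmModel().eval()            # hypothetical module from the sketch above
example_inputs = (torch.randn(1, 32),)

# Capture, annotate with the Neutron patterns (now including MmPattern),
# run a calibration pass, and convert to a statically quantized graph.
captured = torch.export.export_for_training(model, example_inputs).module()
prepared = prepare_pt2e(captured, NeutronQuantizer())
prepared(*example_inputs)           # calibration
quantized = convert_pt2e(prepared)
```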

backends/nxp/quantizer/patterns.py

Lines changed: 53 additions & 63 deletions
```diff
@@ -276,55 +276,20 @@ def get_anchors(
         )


-class Conv1dPattern(QuantizationPattern):
-    def partition_types(self) -> list[OpOverload]:
-        return [torch.ops.aten.conv1d.default]
-
-    def get_anchors(
-        self, gm: fx.GraphModule, fused_partition: list[fx.GraphModule]
-    ) -> PartitionAnchors:
-        # pyre-fixme[29]: `Union[BoundMethod[typing.Callable(torch._C.TensorBase.__ge...
-        conv1d_node = fused_partition[0].nodes[-1]
-
-        bias_qspec = DerivedQuantizationSpec(
-            derived_from=[
-                (conv1d_node.args[0], conv1d_node),
-                (conv1d_node.args[1], conv1d_node),
-            ],
-            derive_qparams_fn=get_bias_qparams,
-            dtype=torch.int32,
-            quant_min=-(2**31),
-            quant_max=2**31 - 1,
-            qscheme=torch.per_tensor_affine,
-        )
-
-        # Keep bias empty if not supplied
-        bias = []
-        if len(conv1d_node.args) > 2 and conv1d_node.args[2] is not None:
-            bias = [(conv1d_node, NodeArgsIdx(2), bias_qspec)]
-
-        return PartitionAnchors(
-            inputs=[(conv1d_node, NodeArgsIdx(0))],
-            weights=[(conv1d_node, NodeArgsIdx(1))],
-            # pyre-fixme[6]: Incompatible parameter type
-            biases=bias,
-            output=[(conv1d_node,)],
-        )
-
-
-class Conv2dPattern(QuantizationPattern):
+class ConvPattern(QuantizationPattern):
+    @abstractmethod
     def partition_types(self) -> list[OpOverload]:
-        return [torch.ops.aten.conv2d.default]
+        pass

     def get_anchors(
         self, gm: fx.GraphModule, fused_partition: list[fx.GraphModule]
     ) -> PartitionAnchors:
-        conv2d_node = fused_partition[0].nodes[-1]
+        conv_node = fused_partition[0].nodes[-1]

         bias_quantization_qspec = DerivedQuantizationSpec(
             derived_from=[
-                (conv2d_node.args[0], conv2d_node),
-                (conv2d_node.args[1], conv2d_node),
+                (conv_node.args[0], conv_node),
+                (conv_node.args[1], conv_node),
             ],
             derive_qparams_fn=get_bias_qparams,
             dtype=torch.int32,
@@ -346,17 +311,27 @@ def get_anchors(

         # Keep bias empty if not supplied
         bias = []
-        if len(conv2d_node.args) > 2 and conv2d_node.args[2] is not None:
-            bias = [(conv2d_node, NodeArgsIdx(2), bias_quantization_qspec)]
+        if len(conv_node.args) > 2 and conv_node.args[2] is not None:
+            bias = [(conv_node, NodeArgsIdx(2), bias_quantization_qspec)]

         return PartitionAnchors(
-            inputs=[(conv2d_node, NodeArgsIdx(0))],
-            weights=[(conv2d_node, NodeArgsIdx(1), weight_quantization_spec)],
+            inputs=[(conv_node, NodeArgsIdx(0))],
+            weights=[(conv_node, NodeArgsIdx(1), weight_quantization_spec)],
             biases=bias,
-            output=[(conv2d_node,)],
+            output=[(conv_node,)],
         )


+class Conv1dPattern(ConvPattern):
+    def partition_types(self) -> list[OpOverload]:
+        return [torch.ops.aten.conv1d.default]
+
+
+class Conv2dPattern(ConvPattern):
+    def partition_types(self) -> list[OpOverload]:
+        return [torch.ops.aten.conv2d.default]
+
+
 class DropoutPattern(SharedSpecPattern):
     """
     Quantizer for Dropout operator.
@@ -432,7 +407,6 @@ def partition_types(self) -> list[OpOverload]:
     def get_anchors(
         self, gm: fx.GraphModule, fused_partition: list[fx.GraphModule]
     ) -> PartitionAnchors:
-        # pyre-fixme[29]: `Union[BoundMethod[typing.Callable(torch._C.TensorBase.__ge...
         linear_node = fused_partition[0].nodes[-1]

         bias_qspec = DerivedQuantizationSpec(
@@ -455,7 +429,6 @@ def get_anchors(
         return PartitionAnchors(
             inputs=[(linear_node, NodeArgsIdx(0))],
             weights=[(linear_node, NodeArgsIdx(1))],
-            # pyre-fixme[6]: Incompatible parameter type
             biases=bias,
             output=[(linear_node,)],
         )
@@ -479,6 +452,23 @@ def partition_types(self):
         return [torch.ops.aten.mean.dim]


+class MmPattern(QuantizationPattern):
+    def partition_types(self) -> list[OpOverload]:
+        return [torch.ops.aten.mm.default]
+
+    def get_anchors(
+        self, gm: fx.GraphModule, fused_partition: list[fx.GraphModule]
+    ) -> PartitionAnchors:
+        mm_node = fused_partition[0].nodes[-1]
+
+        return PartitionAnchors(
+            inputs=[(mm_node, NodeArgsIdx(0))],
+            weights=[(mm_node, NodeArgsIdx(1))],
+            biases=[],
+            output=[(mm_node,)],
+        )
+
+
 class PadPattern(SharedSpecPattern):
     """
     Quantizer for Pad operator.
@@ -552,33 +542,33 @@ def get_anchors(
         )


-class TanhPattern(QuantizationPattern):
+class SigmoidPattern(QuantizationPattern):
     """
-    Quantizer for Tanh operator.
+    Quantizer for Sigmoid operator.

-    The quantization of Tanh output is fixed to scale 1/128, zero point 0, dtype int8.
+    The quantization of Sigmoid output is fixed to scale 1/256, zero point -128, dtype int8.
     """

-    def partition_types(self):
-        return [torch.ops.aten.tanh.default]
+    def partition_types(self) -> list[OpOverload]:
+        return [torch.ops.aten.sigmoid.default]

     def get_anchors(
         self, gm: fx.GraphModule, fused_partition: list[fx.GraphModule]
     ) -> PartitionAnchors:
         return get_anchors_for_fixed_quant_specs(
-            fused_partition, scale=1.0 / 128.0, zero_point=0
+            fused_partition, scale=1.0 / 256.0, zero_point=-128
         )


-class TanhInPlacePattern(QuantizationPattern):
+class TanhPattern(QuantizationPattern):
     """
-    Quantizer for inplace version of Tanh operator (torch.tanh_).
+    Quantizer for Tanh operator.

     The quantization of Tanh output is fixed to scale 1/128, zero point 0, dtype int8.
     """

     def partition_types(self):
-        return [torch.ops.aten.tanh_.default]
+        return [torch.ops.aten.tanh.default]

     def get_anchors(
         self, gm: fx.GraphModule, fused_partition: list[fx.GraphModule]
@@ -588,19 +578,19 @@ def get_anchors(
         )


-class SigmoidPattern(QuantizationPattern):
+class TanhInPlacePattern(QuantizationPattern):
     """
-    Quantizer for Sigmoid operator.
+    Quantizer for inplace version of Tanh operator (torch.tanh_).

-    The quantization of Sigmoid output is fixed to scale 1/256, zero point -128, dtype int8.
+    The quantization of Tanh output is fixed to scale 1/128, zero point 0, dtype int8.
     """

-    def partition_types(self) -> list[OpOverload]:
-        return [torch.ops.aten.sigmoid.default]
+    def partition_types(self):
+        return [torch.ops.aten.tanh_.default]

     def get_anchors(
         self, gm: fx.GraphModule, fused_partition: list[fx.GraphModule]
     ) -> PartitionAnchors:
         return get_anchors_for_fixed_quant_specs(
-            fused_partition, scale=1.0 / 256.0, zero_point=-128
+            fused_partition, scale=1.0 / 128.0, zero_point=0
         )
```
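Note that SigmoidPattern and TanhPattern only trade places in the file here; their fixed output specs are unchanged. As a quick sanity check of what those fixed scale/zero-point pairs mean, the sketch below applies the standard affine formula q = round(x / scale) + zero_point with int8 clamping (hypothetical helper, not part of the codebase):

```python
def quantize_int8(x: float, scale: float, zero_point: int) -> int:
    # Standard affine quantization, clamped to the int8 range.
    return max(-128, min(127, round(x / scale) + zero_point))


# Sigmoid outputs lie in (0, 1): scale 1/256 with zero point -128 spreads
# that interval across the full int8 range.
assert quantize_int8(0.0, 1 / 256, -128) == -128
assert quantize_int8(0.5, 1 / 256, -128) == 0
assert quantize_int8(1.0, 1 / 256, -128) == 127  # 128 clamped to 127

# Tanh outputs lie in (-1, 1): scale 1/128 with zero point 0 does the same.
assert quantize_int8(-1.0, 1 / 128, 0) == -128
assert quantize_int8(0.0, 1 / 128, 0) == 0
assert quantize_int8(1.0, 1 / 128, 0) == 127     # 128 clamped to 127
```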
Lines changed: 89 additions & 0 deletions
```diff
@@ -0,0 +1,89 @@
+# Copyright 2025 NXP
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import unittest
+
+import kgb
+import numpy as np
+import torch
+
+from executorch.backends.nxp.backend.edge_program_converter import (
+    EdgeProgramToIRConverter,
+)
+from executorch.backends.nxp.tests.executorch_pipeline import to_quantized_edge_program
+from executorch.backends.nxp.tests.executors import (
+    convert_run_compare,
+    graph_contains_any_of_ops,
+)
+from executorch.backends.nxp.tests.models import AddmmModule, LinearModule
+from executorch.exir.dialects._ops import ops as exir_ops
+from torch.export import ExportedProgram
+
+
+class TestAddmmConversion(unittest.TestCase):
+    @classmethod
+    def setUpClass(cls):
+        torch.manual_seed(23)
+        np.random.seed(42)
+
+    def test_addmm_conversion(self):
+        with kgb.spy_on(
+            EdgeProgramToIRConverter.convert_program, call_original=True
+        ) as converter_spy:
+            input_shape = (1, 32)
+            model = AddmmModule(input_shape[1])
+
+            edge_program = to_quantized_edge_program(
+                model, input_shape
+            ).exported_program()
+
+            # Make sure that all nodes were delegated.
+            assert not graph_contains_any_of_ops(
+                graph=edge_program.graph, ops=[exir_ops.edge.aten.addmm.default]
+            )
+            assert any(
+                "lowered_module" in node.name for node in edge_program.graph.nodes
+            )
+
+            tflite_flatbuffers_model, io_formats = converter_spy.calls[-1].return_value
+            exported_program: ExportedProgram = converter_spy.calls[-1].args[0]
+            input_data = (np.random.random(input_shape).astype(np.float32) * 50).astype(
+                np.int8
+            )
+            convert_run_compare(
+                exported_program,
+                input_data,
+                tfl_model=tflite_flatbuffers_model,
+            )
+
+    def test_linear_conversion__with_bias(self):
+        with kgb.spy_on(
+            EdgeProgramToIRConverter.convert_program, call_original=True
+        ) as converter_spy:
+            input_shape = (10, 32)
+            model = LinearModule(bias=True)
+
+            edge_program = to_quantized_edge_program(
+                model, input_shape
+            ).exported_program()
+
+            # Make sure that all nodes were delegated.
+            assert not graph_contains_any_of_ops(
+                graph=edge_program.graph, ops=[exir_ops.edge.aten.addmm.default]
+            )
+            assert any(
+                "lowered_module" in node.name for node in edge_program.graph.nodes
+            )
+
+            tflite_flatbuffers_model, io_formats = converter_spy.calls[-1].return_value
+            exported_program: ExportedProgram = converter_spy.calls[-1].args[0]
+            input_data = (np.random.random(input_shape).astype(np.float32) * 50).astype(
+                np.int8
+            )
+            convert_run_compare(
+                exported_program,
+                input_data,
+                tfl_model=tflite_flatbuffers_model,
+            )
```
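The AddmmModule and LinearModule helpers imported above live in executorch.backends.nxp.tests.models and are not shown in this diff. Since the test constructs AddmmModule(input_shape[1]), one plausible shape for it is a module that calls torch.addmm directly (a hypothetical sketch, not the actual definition):

```python
import torch


class AddmmModule(torch.nn.Module):
    """Hypothetical sketch of the test helper; the real definition may differ."""

    def __init__(self, in_features: int):
        super().__init__()
        self.weight = torch.nn.Parameter(torch.randn(in_features, in_features))
        self.bias = torch.nn.Parameter(torch.randn(in_features))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # torch.addmm(bias, x, weight) traces to aten.addmm.default, the op
        # the test asserts has been fully delegated out of the edge graph.
        return torch.addmm(self.bias, x, self.weight)
```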

backends/nxp/tests/ir/converter/node_converter/test_linear_converter.py

Lines changed: 0 additions & 49 deletions
This file was deleted.
