Skip to content

Commit 1fe59c8

Browse files
authored
Adding Test for CadenceWith16BitMatmulActivationsQuantizer
Differential Revision: D88053808 Pull Request resolved: pytorch#16089
1 parent 141174d commit 1fe59c8

File tree

2 files changed

+53
-0
lines changed

2 files changed

+53
-0
lines changed

backends/cadence/aot/TARGETS

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -641,6 +641,9 @@ python_unittest(
641641
typing = True,
642642
deps = [
643643
"//caffe2:torch",
644+
"//executorch/backends/cadence/aot:graph_builder",
644645
"//executorch/backends/cadence/aot/quantizer:quantizer",
646+
"//executorch/exir:pass_base",
647+
"//pytorch/ao:torchao",
645648
],
646649
)

backends/cadence/aot/tests/test_quantizer_ops.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,14 +9,64 @@
99
import unittest
1010

1111
import torch
12+
from executorch.backends.cadence.aot.graph_builder import GraphBuilder
1213
from executorch.backends.cadence.aot.quantizer.patterns import AddmmPattern
1314

1415
from executorch.backends.cadence.aot.quantizer.quantizer import (
1516
CadenceAtenQuantizer,
1617
CadenceDefaultQuantizer,
1718
CadenceW8A32MixedQuantizer,
19+
CadenceWith16BitMatmulActivationsQuantizer,
20+
qconfig_A16,
1821
qconfig_A8W8,
1922
)
23+
from executorch.exir.pass_base import NodeMetadata
24+
from torchao.quantization.pt2e.quantizer.quantizer import (
25+
Q_ANNOTATION_KEY,
26+
QuantizationAnnotation,
27+
)
28+
29+
30+
class QuantizerAnnotationTest(unittest.TestCase):
    """Unit tests for verifying quantizer annotations are correctly applied."""

    def _build_matmul_graph(self) -> tuple[torch.fx.GraphModule, torch.fx.Node]:
        """Construct a minimal graph containing one aten.matmul call.

        Returns the graph module together with the matmul node, after
        asserting the node is present exactly once.
        """
        builder = GraphBuilder()
        lhs = builder.placeholder("x", torch.randn(4, 8))
        rhs = builder.placeholder("y", torch.randn(8, 4))
        # source_fn_stack is what the quantizer pattern-matches against,
        # so it must be populated in the node metadata.
        product = builder.call_operator(
            op=torch.ops.aten.matmul.default,
            args=(lhs, rhs),
            meta=NodeMetadata(
                {"source_fn_stack": [("matmul", torch.ops.aten.matmul.default)]}
            ),
        )
        builder.output([product])
        gm = builder.get_graph_module()

        matmul_nodes = gm.graph.find_nodes(
            op="call_function",
            target=torch.ops.aten.matmul.default,
        )
        self.assertEqual(len(matmul_nodes), 1, "Should find exactly one matmul node")
        return gm, matmul_nodes[0]

    def test_matmul_16bit_quantizer_annotation(self) -> None:
        """Test that CadenceWith16BitMatmulActivationsQuantizer correctly annotates matmul."""
        gm, matmul_node = self._build_matmul_graph()

        CadenceWith16BitMatmulActivationsQuantizer().annotate(gm)

        annotation: QuantizationAnnotation = matmul_node.meta[Q_ANNOTATION_KEY]
        self.assertTrue(annotation._annotated)

        # Output activation must carry the 16-bit qspec.
        self.assertEqual(annotation.output_qspec, qconfig_A16.output_activation)

        # Both matmul operands must be annotated with the 16-bit input qspec.
        self.assertEqual(len(annotation.input_qspec_map), 2)
        for input_qspec in annotation.input_qspec_map.values():
            self.assertEqual(input_qspec, qconfig_A16.input_activation)
2070

2171

2272
class QuantizerOpsPreserveTest(unittest.TestCase):

0 commit comments

Comments
 (0)