 
 # pyre-strict
 
+import inspect
 import unittest
+from typing import Callable
 
 import torch
 from executorch.backends.cadence.aot.graph_builder import GraphBuilder
+from executorch.backends.cadence.aot.quantizer import quantizer as quantizer_module
 from executorch.backends.cadence.aot.quantizer.patterns import AddmmPattern
 
 from executorch.backends.cadence.aot.quantizer.quantizer import (
     CadenceAtenQuantizer,
     CadenceDefaultQuantizer,
+    CadenceFusedConvReluQuantizer,
+    CadenceNopQuantizer,
+    CadenceQuantizer,
     CadenceW8A32MixedQuantizer,
+    CadenceWakeWordQuantizer,
+    CadenceWith16BitConvActivationsQuantizer,
+    CadenceWith16BitLinearActivationsQuantizer,
     CadenceWith16BitMatmulActivationsQuantizer,
+    CadenceWithLayerNormQuantizer,
+    CadenceWithSoftmaxQuantizer,
     qconfig_A16,
     qconfig_A8W8,
 )
 from executorch.exir.pass_base import NodeMetadata
+from parameterized import parameterized
+from torch._ops import OpOverload
 from torchao.quantization.pt2e.quantizer.quantizer import (
     Q_ANNOTATION_KEY,
     QuantizationAnnotation,
+    QuantizationSpec,
 )
 
+# Type alias for graph builder functions.
+# These functions take a test instance and return a graph module and the target op node.
+GraphBuilderFn = Callable[
+    ["QuantizerAnnotationTest"], tuple[torch.fx.GraphModule, torch.fx.Node]
+]
+
+
+# Quantizers intentionally excluded from annotation testing.
+# These should be explicitly justified when added.
+EXCLUDED_FROM_ANNOTATION_TESTING: set[type[CadenceQuantizer]] = {
+    CadenceDefaultQuantizer,  # TODO: T247438143 Add test coverage
+    CadenceFusedConvReluQuantizer,  # TODO: T247438151 Add test coverage
+    CadenceNopQuantizer,  # No-op quantizer, doesn't annotate anything
+    CadenceW8A32MixedQuantizer,  # TODO: T247438158 Add test coverage
+    CadenceWakeWordQuantizer,  # TODO: T247438162 Add test coverage
+    CadenceWith16BitConvActivationsQuantizer,  # TODO: T247438221 Add test coverage
+    CadenceWithLayerNormQuantizer,  # TODO: T247438410 Add test coverage
+    CadenceWithSoftmaxQuantizer,  # TODO: T247438418 Add test coverage
+}
+
+
+# Test case definitions for quantizer annotation tests.
+# Format: (name, graph_builder_fn, quantizer_instance, target_op, expected_output_qspec, expected_input_qspecs)
+# Adding a new quantizer test only requires adding a tuple to this list.
+QUANTIZER_ANNOTATION_TEST_CASES: list[
+    tuple[
+        str,
+        GraphBuilderFn,
+        CadenceQuantizer,
+        OpOverload,
+        QuantizationSpec,
+        list[QuantizationSpec],
+    ]
+] = [
+    (
+        "matmul_A16",
+        lambda self: self._build_matmul_graph(),
+        CadenceWith16BitMatmulActivationsQuantizer(),
+        torch.ops.aten.matmul.default,
+        qconfig_A16.output_activation,
+        # For matmul, both inputs are activations
+        [qconfig_A16.input_activation, qconfig_A16.input_activation],
+    ),
+    (
+        "linear_A16",
+        lambda self: self._build_linear_graph(),
+        CadenceWith16BitLinearActivationsQuantizer(),
+        torch.ops.aten.linear.default,
+        qconfig_A16.output_activation,
+        # For linear: [input_activation, weight]
+        [qconfig_A16.input_activation, qconfig_A16.weight],
+    ),
+]
+
+# Derive the set of tested quantizer classes from the test cases.
+# This ensures TESTED_QUANTIZER_CLASSES stays in sync with actual tests.
+TESTED_QUANTIZER_CLASSES: set[type[CadenceQuantizer]] = {
+    type(case[2]) for case in QUANTIZER_ANNOTATION_TEST_CASES
+}
+
 
 class QuantizerAnnotationTest(unittest.TestCase):
     """Unit tests for verifying quantizer annotations are correctly applied."""
@@ -52,21 +126,90 @@ def _build_matmul_graph(self) -> tuple[torch.fx.GraphModule, torch.fx.Node]:
         self.assertEqual(len(matmul_nodes), 1, "Should find exactly one matmul node")
         return gm, matmul_nodes[0]
 
-    def test_matmul_16bit_quantizer_annotation(self) -> None:
-        """Test that CadenceWith16BitMatmulActivationsQuantizer correctly annotates matmul."""
-        gm, matmul_node = self._build_matmul_graph()
+    def _build_linear_graph(self) -> tuple[torch.fx.GraphModule, torch.fx.Node]:
+        """Build a simple graph with a linear operation (no bias)."""
+        builder = GraphBuilder()
+        x = builder.placeholder("x", torch.randn(1, 10))
+        weight = builder.placeholder("weight", torch.randn(5, 10))
+        linear = builder.call_operator(
+            op=torch.ops.aten.linear.default,
+            args=(x, weight),
+            meta=NodeMetadata(
+                {"source_fn_stack": [("linear", torch.ops.aten.linear.default)]}
+            ),
+        )
+        builder.output([linear])
+        gm = builder.get_graph_module()
+
+        linear_nodes = gm.graph.find_nodes(
+            op="call_function",
+            target=torch.ops.aten.linear.default,
+        )
+        self.assertEqual(len(linear_nodes), 1, "Should find exactly one linear node")
+        return gm, linear_nodes[0]
+
+    @parameterized.expand(QUANTIZER_ANNOTATION_TEST_CASES)
+    def test_quantizer_annotation(
+        self,
+        name: str,
+        graph_builder_fn: GraphBuilderFn,
+        quantizer: CadenceQuantizer,
+        target: OpOverload,
+        expected_output_qspec: QuantizationSpec,
+        expected_input_qspecs: list[QuantizationSpec],
+    ) -> None:
+        """Parameterized test for quantizer annotations."""
+        gm, op_node = graph_builder_fn(self)
 
-        quantizer = CadenceWith16BitMatmulActivationsQuantizer()
         quantizer.annotate(gm)
 
-        annotation: QuantizationAnnotation = matmul_node.meta[Q_ANNOTATION_KEY]
+        annotation: QuantizationAnnotation = op_node.meta[Q_ANNOTATION_KEY]
         self.assertTrue(annotation._annotated)
 
-        self.assertEqual(annotation.output_qspec, qconfig_A16.output_activation)
+        # Verify output annotation
+        self.assertEqual(annotation.output_qspec, expected_output_qspec)
+
+        # Verify input annotations
+        self.assertEqual(len(annotation.input_qspec_map), len(expected_input_qspecs))
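+        # Entries in input_qspec_map are expected to follow the positional order
+        # of op_node.args; the checks below verify both the input node and its qspec.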
+        for i, (input_node, input_qspec) in enumerate(
+            annotation.input_qspec_map.items()
+        ):
+            expected_arg = op_node.args[i]
+            assert isinstance(expected_arg, torch.fx.Node)
+            self.assertEqual(
+                input_node,
+                expected_arg,
+                f"Input node mismatch at index {i}",
+            )
+            self.assertEqual(
+                input_qspec,
+                expected_input_qspecs[i],
+                f"Input qspec mismatch at index {i}",
+            )
 
-        self.assertEqual(len(annotation.input_qspec_map), 2)
-        for _, input_qspec in annotation.input_qspec_map.items():
-            self.assertEqual(input_qspec, qconfig_A16.input_activation)
+    def test_all_quantizers_have_annotation_tests(self) -> None:
+        """Ensure every CadenceQuantizer subclass is either tested or explicitly excluded."""
+        # Get all CadenceQuantizer subclasses defined in the quantizer module
+        all_quantizers: set[type[CadenceQuantizer]] = set()
+        for _, obj in inspect.getmembers(quantizer_module, inspect.isclass):
+            if (
+                issubclass(obj, CadenceQuantizer)
+                and obj is not CadenceQuantizer
+                and obj.__module__ == quantizer_module.__name__
+            ):
+                all_quantizers.add(obj)
+
+        # Check for missing tests
+        untested = (
+            all_quantizers - TESTED_QUANTIZER_CLASSES - EXCLUDED_FROM_ANNOTATION_TESTING
+        )
+        if untested:
+            untested_names = sorted(cls.__name__ for cls in untested)
+            self.fail(
+                f"The following CadenceQuantizer subclasses are not tested in "
+                f"test_quantizer_annotation and not in EXCLUDED_FROM_ANNOTATION_TESTING: "
+                f"{untested_names}. Please add test cases or explicitly exclude them."
+            )
 
 
 class QuantizerOpsPreserveTest(unittest.TestCase):