Commit 0e2a6fb

Adding Test To Ensure All Future Quantizers Are Tested
Differential Revision: D88055443
Pull Request resolved: #16099
1 parent e804065 commit 0e2a6fb

File tree: 2 files changed (+153 −9 lines)


backends/cadence/aot/TARGETS

Lines changed: 1 addition & 0 deletions

@@ -640,6 +640,7 @@ python_unittest(
     ],
     typing = True,
     deps = [
+        "fbsource//third-party/pypi/parameterized:parameterized",
         "//caffe2:torch",
         "//executorch/backends/cadence/aot:graph_builder",
         "//executorch/backends/cadence/aot/quantizer:quantizer",

backends/cadence/aot/tests/test_quantizer_ops.py

Lines changed: 152 additions & 9 deletions
@@ -6,26 +6,100 @@

 # pyre-strict

+import inspect
 import unittest
+from typing import Callable

 import torch
 from executorch.backends.cadence.aot.graph_builder import GraphBuilder
+from executorch.backends.cadence.aot.quantizer import quantizer as quantizer_module
 from executorch.backends.cadence.aot.quantizer.patterns import AddmmPattern

 from executorch.backends.cadence.aot.quantizer.quantizer import (
     CadenceAtenQuantizer,
     CadenceDefaultQuantizer,
+    CadenceFusedConvReluQuantizer,
+    CadenceNopQuantizer,
+    CadenceQuantizer,
     CadenceW8A32MixedQuantizer,
+    CadenceWakeWordQuantizer,
+    CadenceWith16BitConvActivationsQuantizer,
+    CadenceWith16BitLinearActivationsQuantizer,
     CadenceWith16BitMatmulActivationsQuantizer,
+    CadenceWithLayerNormQuantizer,
+    CadenceWithSoftmaxQuantizer,
     qconfig_A16,
     qconfig_A8W8,
 )
 from executorch.exir.pass_base import NodeMetadata
+from parameterized import parameterized
+from torch._ops import OpOverload
 from torchao.quantization.pt2e.quantizer.quantizer import (
     Q_ANNOTATION_KEY,
     QuantizationAnnotation,
+    QuantizationSpec,
 )

+# Type alias for graph builder functions.
+# These functions take a test instance and return a graph module and the target op node.
+GraphBuilderFn = Callable[
+    ["QuantizerAnnotationTest"], tuple[torch.fx.GraphModule, torch.fx.Node]
+]
+
+
+# Quantizers intentionally excluded from annotation testing.
+# These should be explicitly justified when added.
+EXCLUDED_FROM_ANNOTATION_TESTING: set[type[CadenceQuantizer]] = {
+    CadenceDefaultQuantizer,  # TODO: T247438143 Add test coverage
+    CadenceFusedConvReluQuantizer,  # TODO: T247438151 Add test coverage
+    CadenceNopQuantizer,  # No-op quantizer, doesn't annotate anything
+    CadenceW8A32MixedQuantizer,  # TODO: T247438158 Add test coverage
+    CadenceWakeWordQuantizer,  # TODO: T247438162 Add test coverage
+    CadenceWith16BitConvActivationsQuantizer,  # TODO: T247438221 Add test coverage
+    CadenceWithLayerNormQuantizer,  # TODO: T247438410 Add test coverage
+    CadenceWithSoftmaxQuantizer,  # TODO: T247438418 Add test coverage
+}
+
+
+# Test case definitions for quantizer annotation tests.
+# Format: (name, graph_builder_fn, quantizer_instance, target_op, expected_output_qspec, expected_input_qspecs)
+# Adding a new quantizer test only requires adding a tuple to this list.
+QUANTIZER_ANNOTATION_TEST_CASES: list[
+    tuple[
+        str,
+        GraphBuilderFn,
+        CadenceQuantizer,
+        OpOverload,
+        QuantizationSpec,
+        list[QuantizationSpec],
+    ]
+] = [
+    (
+        "matmul_A16",
+        lambda self: self._build_matmul_graph(),
+        CadenceWith16BitMatmulActivationsQuantizer(),
+        torch.ops.aten.matmul.default,
+        qconfig_A16.output_activation,
+        # For matmul, both inputs are activations
+        [qconfig_A16.input_activation, qconfig_A16.input_activation],
+    ),
+    (
+        "linear_A16",
+        lambda self: self._build_linear_graph(),
+        CadenceWith16BitLinearActivationsQuantizer(),
+        torch.ops.aten.linear.default,
+        qconfig_A16.output_activation,
+        # For linear: [input_activation, weight]
+        [qconfig_A16.input_activation, qconfig_A16.weight],
+    ),
+]
+
+# Derive the set of tested quantizer classes from the test cases.
+# This ensures TESTED_QUANTIZER_CLASSES stays in sync with actual tests.
+TESTED_QUANTIZER_CLASSES: set[type[CadenceQuantizer]] = {
+    type(case[2]) for case in QUANTIZER_ANNOTATION_TEST_CASES
+}
+

 class QuantizerAnnotationTest(unittest.TestCase):
     """Unit tests for verifying quantizer annotations are correctly applied."""
@@ -52,21 +126,90 @@ def _build_matmul_graph(self) -> tuple[torch.fx.GraphModule, torch.fx.Node]:
         self.assertEqual(len(matmul_nodes), 1, "Should find exactly one matmul node")
         return gm, matmul_nodes[0]

-    def test_matmul_16bit_quantizer_annotation(self) -> None:
-        """Test that CadenceWith16BitMatmulActivationsQuantizer correctly annotates matmul."""
-        gm, matmul_node = self._build_matmul_graph()
+    def _build_linear_graph(self) -> tuple[torch.fx.GraphModule, torch.fx.Node]:
+        """Build a simple graph with a linear operation (no bias)."""
+        builder = GraphBuilder()
+        x = builder.placeholder("x", torch.randn(1, 10))
+        weight = builder.placeholder("weight", torch.randn(5, 10))
+        linear = builder.call_operator(
+            op=torch.ops.aten.linear.default,
+            args=(x, weight),
+            meta=NodeMetadata(
+                {"source_fn_stack": [("linear", torch.ops.aten.linear.default)]}
+            ),
+        )
+        builder.output([linear])
+        gm = builder.get_graph_module()
+
+        linear_nodes = gm.graph.find_nodes(
+            op="call_function",
+            target=torch.ops.aten.linear.default,
+        )
+        self.assertEqual(len(linear_nodes), 1, "Should find exactly one linear node")
+        return gm, linear_nodes[0]
+
+    @parameterized.expand(QUANTIZER_ANNOTATION_TEST_CASES)
+    def test_quantizer_annotation(
+        self,
+        name: str,
+        graph_builder_fn: GraphBuilderFn,
+        quantizer: CadenceQuantizer,
+        target: OpOverload,
+        expected_output_qspec: QuantizationSpec,
+        expected_input_qspecs: list[QuantizationSpec],
+    ) -> None:
+        """Parameterized test for quantizer annotations."""
+        gm, op_node = graph_builder_fn(self)

-        quantizer = CadenceWith16BitMatmulActivationsQuantizer()
         quantizer.annotate(gm)

-        annotation: QuantizationAnnotation = matmul_node.meta[Q_ANNOTATION_KEY]
+        annotation: QuantizationAnnotation = op_node.meta[Q_ANNOTATION_KEY]
         self.assertTrue(annotation._annotated)

-        self.assertEqual(annotation.output_qspec, qconfig_A16.output_activation)
+        # Verify output annotation
+        self.assertEqual(annotation.output_qspec, expected_output_qspec)
+
+        # Verify input annotations
+        self.assertEqual(len(annotation.input_qspec_map), len(expected_input_qspecs))
+        for i, (input_node, input_qspec) in enumerate(
+            annotation.input_qspec_map.items()
+        ):
+            expected_arg = op_node.args[i]
+            assert isinstance(expected_arg, torch.fx.Node)
+            self.assertEqual(
+                input_node,
+                expected_arg,
+                f"Input node mismatch at index {i}",
+            )
+            self.assertEqual(
+                input_qspec,
+                expected_input_qspecs[i],
+                f"Input qspec mismatch at index {i}",
+            )

-        self.assertEqual(len(annotation.input_qspec_map), 2)
-        for _, input_qspec in annotation.input_qspec_map.items():
-            self.assertEqual(input_qspec, qconfig_A16.input_activation)
+    def test_all_quantizers_have_annotation_tests(self) -> None:
+        """Ensure every CadenceQuantizer subclass is either tested or explicitly excluded."""
+        # Get all CadenceQuantizer subclasses defined in the quantizer module
+        all_quantizers: set[type[CadenceQuantizer]] = set()
+        for _, obj in inspect.getmembers(quantizer_module, inspect.isclass):
+            if (
+                issubclass(obj, CadenceQuantizer)
+                and obj is not CadenceQuantizer
+                and obj.__module__ == quantizer_module.__name__
+            ):
+                all_quantizers.add(obj)
+
+        # Check for missing tests
+        untested = (
+            all_quantizers - TESTED_QUANTIZER_CLASSES - EXCLUDED_FROM_ANNOTATION_TESTING
+        )
+        if untested:
+            untested_names = sorted(cls.__name__ for cls in untested)
+            self.fail(
+                f"The following CadenceQuantizer subclasses are not tested in "
+                f"test_quantizer_annotation and not in EXCLUDED_FROM_ANNOTATION_TESTING: "
+                f"{untested_names}. Please add test cases or explicitly exclude them."
+            )


 class QuantizerOpsPreserveTest(unittest.TestCase):
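The discovery step in test_all_quantizers_have_annotation_tests is ordinary reflection over the quantizer module. The same pattern, as a self-contained sketch (the helper name local_subclasses is ours, not the commit's):

import inspect
from types import ModuleType


def local_subclasses(module: ModuleType, base: type) -> set[type]:
    """Collect subclasses of `base` defined in `module` itself, skipping
    the base class and classes merely imported into the module."""
    return {
        obj
        for _, obj in inspect.getmembers(module, inspect.isclass)
        if issubclass(obj, base)
        and obj is not base
        and obj.__module__ == module.__name__
    }

The __module__ check keeps classes re-exported from other modules out of the set, so only quantizers actually defined in quantizer.py can trip the guard.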
