Skip to content

Commit 913436a

Browse files
authored
Adding Tests for CadenceFusedConvReluQuantizer (#16358)
1 parent c730feb commit 913436a

File tree

1 file changed

+142
-27
lines changed

1 file changed

+142
-27
lines changed

backends/cadence/aot/tests/test_quantizer_ops.py

Lines changed: 142 additions & 27 deletions
Original file line number · Diff line number · Diff line change
@@ -46,15 +46,18 @@
4646

4747
# Type alias for graph builder functions.
4848
# These functions take a test instance and return a graph module and the target op node.
49+
# For fused patterns (e.g., conv+relu), an optional third element specifies the node
50+
# whose args contain the quantized inputs (e.g., conv node for conv+relu fusion).
4951
GraphBuilderFn = Callable[
50-
["QuantizerAnnotationTest"], tuple[torch.fx.GraphModule, torch.fx.Node]
52+
["QuantizerAnnotationTest"],
53+
tuple[torch.fx.GraphModule, torch.fx.Node]
54+
| tuple[torch.fx.GraphModule, torch.fx.Node, torch.fx.Node],
5155
]
5256

5357

5458
# Quantizers intentionally excluded from annotation testing.
5559
# These should be explicitly justified when added.
5660
EXCLUDED_FROM_ANNOTATION_TESTING: set[type[CadenceQuantizer]] = {
57-
CadenceFusedConvReluQuantizer, # TODO: T247438151 Add test coverage
5861
CadenceNopQuantizer, # No-op quantizer, doesn't annotate anything
5962
CadenceW8A32MixedQuantizer, # TODO: T247438158 Add test coverage
6063
CadenceRmsNormNopQuantizer, # No-op quantizer, doesn't annotate anything, preserves rms_norm from decomposition
@@ -64,14 +67,15 @@
6467
# Test case definitions for quantizer annotation tests.
6568
# Format: (name, graph_builder_fn, quantizer_instance, target_op, expected_output_qspec, expected_input_qspecs)
6669
# Adding a new quantizer test only requires adding a tuple to this list.
70+
# Note: Use None in expected_input_qspecs to skip comparison for that input (e.g., for DerivedQuantizationSpec).
6771
QUANTIZER_ANNOTATION_TEST_CASES: list[
6872
tuple[
6973
str,
7074
GraphBuilderFn,
7175
CadenceQuantizer,
7276
OpOverload,
7377
QuantizationSpec,
74-
list[QuantizationSpec],
78+
list[QuantizationSpec | None],
7579
]
7680
] = [
7781
(
@@ -192,6 +196,26 @@
192196
# For relu: only input_activation
193197
[qconfig_A8W8.input_activation],
194198
),
199+
(
200+
"default_addmm_A8W8",
201+
lambda self: self._build_addmm_graph(),
202+
CadenceDefaultQuantizer(),
203+
torch.ops.aten.addmm.default,
204+
qconfig_A8W8.output_activation,
205+
# For addmm: [bias (DerivedQuantizationSpec), mat1, mat2]
206+
# Use None to skip comparison for bias since it's a DerivedQuantizationSpec
207+
[None, qconfig_A8W8.input_activation, qconfig_A8W8.weight],
208+
),
209+
# CadenceFusedConvReluQuantizer test cases
210+
(
211+
"fused_conv2d_relu_A8W8sym",
212+
lambda self: self._build_conv2d_relu_graph(),
213+
CadenceFusedConvReluQuantizer(),
214+
torch.ops.aten.relu.default,
215+
qconfig_A8W8sym.output_activation,
216+
# For fused conv2d+relu: [input_activation, weight] from conv2d node
217+
[qconfig_A8W8sym.input_activation, qconfig_A8W8sym.weight],
218+
),
195219
]
196220

197221
# Derive the set of tested quantizer classes from the test cases.
@@ -408,6 +432,77 @@ def _build_relu_graph(self) -> tuple[torch.fx.GraphModule, torch.fx.Node]:
408432
self.assertEqual(len(relu_nodes), 1, "Should find exactly one relu node")
409433
return gm, relu_nodes[0]
410434

435+
def _build_addmm_graph(self) -> tuple[torch.fx.GraphModule, torch.fx.Node]:
    """Construct a minimal FX graph containing a single aten.addmm op.

    ``addmm`` computes ``bias + mat1 @ mat2``; the three placeholders are
    created in that same argument order.

    Returns:
        A tuple of (graph_module, addmm_node) where addmm_node is the sole
        ``aten.addmm.default`` call in the graph.
    """
    # Hoist the overload so the builder args, metadata, and node lookup all
    # refer to the exact same op object.
    addmm_op = torch.ops.aten.addmm.default

    gb = GraphBuilder()
    inputs = (
        gb.placeholder("bias", torch.randn(5)),
        gb.placeholder("mat1", torch.randn(1, 10)),
        gb.placeholder("mat2", torch.randn(10, 5)),
    )
    node = gb.call_operator(
        op=addmm_op,
        args=inputs,
        meta=NodeMetadata({"source_fn_stack": [("addmm", addmm_op)]}),
    )
    gb.output([node])
    gm = gb.get_graph_module()

    # Sanity-check the built graph before handing it to the test body.
    matches = gm.graph.find_nodes(op="call_function", target=addmm_op)
    self.assertEqual(len(matches), 1, "Should find exactly one addmm node")
    return gm, matches[0]
459+
460+
def _build_conv2d_relu_graph(
    self,
) -> tuple[torch.fx.GraphModule, torch.fx.Node, torch.fx.Node]:
    """Construct an FX graph with a conv2d feeding directly into a relu.

    This is the fused conv+relu pattern: the relu node is where the output
    annotation lands, while the conv node's args carry the quantized inputs.

    Returns:
        A tuple of (graph_module, relu_node, conv_node).
        The relu_node is the target node where the annotation is placed.
        The conv_node is the input source node whose args contain the
        quantized inputs.
    """
    # Hoist both overloads so building and node lookup share one op object.
    conv_op = torch.ops.aten.conv2d.default
    relu_op = torch.ops.aten.relu.default

    gb = GraphBuilder()
    # Activation layout: (batch, in_channels, height, width).
    activation = gb.placeholder("x", torch.randn(1, 3, 8, 8))
    # Weight layout: (out_channels, in_channels, kernel_h, kernel_w).
    kernel = gb.placeholder("weight", torch.randn(6, 3, 3, 3))

    conv_node = gb.call_operator(
        op=conv_op,
        args=(activation, kernel),
        meta=NodeMetadata({"source_fn_stack": [("conv2d", conv_op)]}),
    )
    relu_node = gb.call_operator(
        op=relu_op,
        args=(conv_node,),
        meta=NodeMetadata({"source_fn_stack": [("relu", relu_op)]}),
    )
    gb.output([relu_node])
    gm = gb.get_graph_module()

    # Re-locate both nodes in the finished module and verify uniqueness.
    relu_matches = gm.graph.find_nodes(op="call_function", target=relu_op)
    self.assertEqual(len(relu_matches), 1, "Should find exactly one relu node")

    conv_matches = gm.graph.find_nodes(op="call_function", target=conv_op)
    self.assertEqual(len(conv_matches), 1, "Should find exactly one conv2d node")

    return gm, relu_matches[0], conv_matches[0]
505+
411506
@parameterized.expand(QUANTIZER_ANNOTATION_TEST_CASES)
412507
def test_quantizer_annotation(
413508
self,
@@ -416,36 +511,56 @@ def test_quantizer_annotation(
416511
quantizer: CadenceQuantizer,
417512
target: OpOverload,
418513
expected_output_qspec: QuantizationSpec,
419-
expected_input_qspecs: list[QuantizationSpec],
514+
expected_input_qspecs: list[QuantizationSpec | None],
420515
) -> None:
421516
"""Parameterized test for quantizer annotations."""
422-
gm, op_node = graph_builder_fn(self)
517+
result = graph_builder_fn(self)
518+
# Handle both 2-element and 3-element returns from graph builders.
519+
# For fused patterns, the 3rd element specifies the node whose args
520+
# contain the quantized inputs (e.g., conv node for conv+relu fusion).
521+
if len(result) == 3:
522+
gm = result[0]
523+
output_node = result[1]
524+
input_source_node = result[2]
525+
else:
526+
gm = result[0]
527+
output_node = result[1]
528+
input_source_node = output_node
423529

424530
quantizer.annotate(gm)
425531

426-
annotation: QuantizationAnnotation = op_node.meta[Q_ANNOTATION_KEY]
427-
self.assertTrue(annotation._annotated)
428-
429-
# Verify output annotation
430-
self.assertEqual(annotation.output_qspec, expected_output_qspec)
431-
432-
# Verify input annotations
433-
self.assertEqual(len(annotation.input_qspec_map), len(expected_input_qspecs))
434-
for i, (input_node, input_qspec) in enumerate(
435-
annotation.input_qspec_map.items()
436-
):
437-
expected_arg = op_node.args[i]
438-
assert isinstance(expected_arg, torch.fx.Node)
439-
self.assertEqual(
440-
input_node,
441-
expected_arg,
442-
f"Input node mismatch at index {i}",
443-
)
444-
self.assertEqual(
445-
input_qspec,
446-
expected_input_qspecs[i],
447-
f"Input qspec mismatch at index {i}",
532+
# Verify output annotation (always on the output node)
533+
output_annotation: QuantizationAnnotation = output_node.meta[Q_ANNOTATION_KEY]
534+
self.assertTrue(output_annotation._annotated)
535+
self.assertEqual(output_annotation.output_qspec, expected_output_qspec)
536+
537+
# Verify input annotations (on the input source node, which may differ for fused patterns)
538+
input_annotation: QuantizationAnnotation = input_source_node.meta[
539+
Q_ANNOTATION_KEY
540+
]
541+
self.assertEqual(
542+
len(input_annotation.input_qspec_map), len(expected_input_qspecs)
543+
)
544+
for input_node, input_qspec in input_annotation.input_qspec_map.items():
545+
# Find the index of this input node in the input source node's args
546+
arg_index = None
547+
args = input_source_node.args
548+
assert isinstance(args, tuple)
549+
for i, arg in enumerate(args):
550+
if arg is input_node:
551+
arg_index = i
552+
break
553+
self.assertIsNotNone(
554+
arg_index,
555+
f"Input node {input_node} not found in input_source_node.args",
448556
)
557+
# Skip comparison if expected qspec is None (e.g., for DerivedQuantizationSpec)
558+
if expected_input_qspecs[arg_index] is not None:
559+
self.assertEqual(
560+
input_qspec,
561+
expected_input_qspecs[arg_index],
562+
f"Input qspec mismatch at arg index {arg_index}",
563+
)
449564

450565
def test_all_quantizers_have_annotation_tests(self) -> None:
451566
"""Ensure every CadenceQuantizer subclass is either tested or explicitly excluded."""

0 commit comments

Comments (0)