# Type alias for graph builder functions.
# These functions take a test instance and return a graph module and the target op node.
# For fused patterns (e.g., conv+relu), an optional third element specifies the node
# whose args contain the quantized inputs (e.g., conv node for conv+relu fusion).
GraphBuilderFn = Callable[
    ["QuantizerAnnotationTest"],
    tuple[torch.fx.GraphModule, torch.fx.Node]
    | tuple[torch.fx.GraphModule, torch.fx.Node, torch.fx.Node],
]


5458# Quantizers intentionally excluded from annotation testing.
5559# These should be explicitly justified when added.
5660EXCLUDED_FROM_ANNOTATION_TESTING : set [type [CadenceQuantizer ]] = {
57- CadenceFusedConvReluQuantizer , # TODO: T247438151 Add test coverage
5861 CadenceNopQuantizer , # No-op quantizer, doesn't annotate anything
5962 CadenceW8A32MixedQuantizer , # TODO: T247438158 Add test coverage
6063 CadenceRmsNormNopQuantizer , # No-op quantizer, doesn't annotate anything, preserves rms_norm from decomposition
6467# Test case definitions for quantizer annotation tests.
6568# Format: (name, graph_builder_fn, quantizer_instance, target_op, expected_output_qspec, expected_input_qspecs)
6669# Adding a new quantizer test only requires adding a tuple to this list.
70+ # Note: Use None in expected_input_qspecs to skip comparison for that input (e.g., for DerivedQuantizationSpec).
6771QUANTIZER_ANNOTATION_TEST_CASES : list [
6872 tuple [
6973 str ,
7074 GraphBuilderFn ,
7175 CadenceQuantizer ,
7276 OpOverload ,
7377 QuantizationSpec ,
74- list [QuantizationSpec ],
78+ list [QuantizationSpec | None ],
7579 ]
7680] = [
7781 (
192196 # For relu: only input_activation
193197 [qconfig_A8W8 .input_activation ],
194198 ),
199+ (
200+ "default_addmm_A8W8" ,
201+ lambda self : self ._build_addmm_graph (),
202+ CadenceDefaultQuantizer (),
203+ torch .ops .aten .addmm .default ,
204+ qconfig_A8W8 .output_activation ,
205+ # For addmm: [bias (DerivedQuantizationSpec), mat1, mat2]
206+ # Use None to skip comparison for bias since it's a DerivedQuantizationSpec
207+ [None , qconfig_A8W8 .input_activation , qconfig_A8W8 .weight ],
208+ ),
209+ # CadenceFusedConvReluQuantizer test cases
210+ (
211+ "fused_conv2d_relu_A8W8sym" ,
212+ lambda self : self ._build_conv2d_relu_graph (),
213+ CadenceFusedConvReluQuantizer (),
214+ torch .ops .aten .relu .default ,
215+ qconfig_A8W8sym .output_activation ,
216+ # For fused conv2d+relu: [input_activation, weight] from conv2d node
217+ [qconfig_A8W8sym .input_activation , qconfig_A8W8sym .weight ],
218+ ),
195219]
196220
197221# Derive the set of tested quantizer classes from the test cases.
@@ -408,6 +432,77 @@ def _build_relu_graph(self) -> tuple[torch.fx.GraphModule, torch.fx.Node]:
408432 self .assertEqual (len (relu_nodes ), 1 , "Should find exactly one relu node" )
409433 return gm , relu_nodes [0 ]
410434
435+ def _build_addmm_graph (self ) -> tuple [torch .fx .GraphModule , torch .fx .Node ]:
436+ """Build a simple graph with an addmm operation."""
437+ builder = GraphBuilder ()
438+ # addmm: bias + (mat1 @ mat2)
439+ # args: (bias, mat1, mat2)
440+ bias = builder .placeholder ("bias" , torch .randn (5 ))
441+ mat1 = builder .placeholder ("mat1" , torch .randn (1 , 10 ))
442+ mat2 = builder .placeholder ("mat2" , torch .randn (10 , 5 ))
443+ addmm = builder .call_operator (
444+ op = torch .ops .aten .addmm .default ,
445+ args = (bias , mat1 , mat2 ),
446+ meta = NodeMetadata (
447+ {"source_fn_stack" : [("addmm" , torch .ops .aten .addmm .default )]}
448+ ),
449+ )
450+ builder .output ([addmm ])
451+ gm = builder .get_graph_module ()
452+
453+ addmm_nodes = gm .graph .find_nodes (
454+ op = "call_function" ,
455+ target = torch .ops .aten .addmm .default ,
456+ )
457+ self .assertEqual (len (addmm_nodes ), 1 , "Should find exactly one addmm node" )
458+ return gm , addmm_nodes [0 ]
459+
460+ def _build_conv2d_relu_graph (
461+ self ,
462+ ) -> tuple [torch .fx .GraphModule , torch .fx .Node , torch .fx .Node ]:
463+ """Build a graph with a conv2d followed by relu (fused pattern).
464+
465+ Returns:
466+ A tuple of (graph_module, relu_node, conv_node).
467+ The relu_node is the target node where the annotation is placed.
468+ The conv_node is the input source node whose args contain the quantized inputs.
469+ """
470+ builder = GraphBuilder ()
471+ # Input shape: (batch, in_channels, height, width)
472+ x = builder .placeholder ("x" , torch .randn (1 , 3 , 8 , 8 ))
473+ # Weight shape: (out_channels, in_channels, kernel_h, kernel_w)
474+ weight = builder .placeholder ("weight" , torch .randn (6 , 3 , 3 , 3 ))
475+ conv2d = builder .call_operator (
476+ op = torch .ops .aten .conv2d .default ,
477+ args = (x , weight ),
478+ meta = NodeMetadata (
479+ {"source_fn_stack" : [("conv2d" , torch .ops .aten .conv2d .default )]}
480+ ),
481+ )
482+ relu = builder .call_operator (
483+ op = torch .ops .aten .relu .default ,
484+ args = (conv2d ,),
485+ meta = NodeMetadata (
486+ {"source_fn_stack" : [("relu" , torch .ops .aten .relu .default )]}
487+ ),
488+ )
489+ builder .output ([relu ])
490+ gm = builder .get_graph_module ()
491+
492+ relu_nodes = gm .graph .find_nodes (
493+ op = "call_function" ,
494+ target = torch .ops .aten .relu .default ,
495+ )
496+ self .assertEqual (len (relu_nodes ), 1 , "Should find exactly one relu node" )
497+
498+ conv2d_nodes = gm .graph .find_nodes (
499+ op = "call_function" ,
500+ target = torch .ops .aten .conv2d .default ,
501+ )
502+ self .assertEqual (len (conv2d_nodes ), 1 , "Should find exactly one conv2d node" )
503+
504+ return gm , relu_nodes [0 ], conv2d_nodes [0 ]
505+
411506 @parameterized .expand (QUANTIZER_ANNOTATION_TEST_CASES )
412507 def test_quantizer_annotation (
413508 self ,
@@ -416,36 +511,56 @@ def test_quantizer_annotation(
416511 quantizer : CadenceQuantizer ,
417512 target : OpOverload ,
418513 expected_output_qspec : QuantizationSpec ,
419- expected_input_qspecs : list [QuantizationSpec ],
514+ expected_input_qspecs : list [QuantizationSpec | None ],
420515 ) -> None :
421516 """Parameterized test for quantizer annotations."""
422- gm , op_node = graph_builder_fn (self )
517+ result = graph_builder_fn (self )
518+ # Handle both 2-element and 3-element returns from graph builders.
519+ # For fused patterns, the 3rd element specifies the node whose args
520+ # contain the quantized inputs (e.g., conv node for conv+relu fusion).
521+ if len (result ) == 3 :
522+ gm = result [0 ]
523+ output_node = result [1 ]
524+ input_source_node = result [2 ]
525+ else :
526+ gm = result [0 ]
527+ output_node = result [1 ]
528+ input_source_node = output_node
423529
424530 quantizer .annotate (gm )
425531
426- annotation : QuantizationAnnotation = op_node .meta [Q_ANNOTATION_KEY ]
427- self .assertTrue (annotation ._annotated )
428-
429- # Verify output annotation
430- self .assertEqual (annotation .output_qspec , expected_output_qspec )
431-
432- # Verify input annotations
433- self .assertEqual (len (annotation .input_qspec_map ), len (expected_input_qspecs ))
434- for i , (input_node , input_qspec ) in enumerate (
435- annotation .input_qspec_map .items ()
436- ):
437- expected_arg = op_node .args [i ]
438- assert isinstance (expected_arg , torch .fx .Node )
439- self .assertEqual (
440- input_node ,
441- expected_arg ,
442- f"Input node mismatch at index { i } " ,
443- )
444- self .assertEqual (
445- input_qspec ,
446- expected_input_qspecs [i ],
447- f"Input qspec mismatch at index { i } " ,
532+ # Verify output annotation (always on the output node)
533+ output_annotation : QuantizationAnnotation = output_node .meta [Q_ANNOTATION_KEY ]
534+ self .assertTrue (output_annotation ._annotated )
535+ self .assertEqual (output_annotation .output_qspec , expected_output_qspec )
536+
537+ # Verify input annotations (on the input source node, which may differ for fused patterns)
538+ input_annotation : QuantizationAnnotation = input_source_node .meta [
539+ Q_ANNOTATION_KEY
540+ ]
541+ self .assertEqual (
542+ len (input_annotation .input_qspec_map ), len (expected_input_qspecs )
543+ )
544+ for input_node , input_qspec in input_annotation .input_qspec_map .items ():
545+ # Find the index of this input node in the input source node's args
546+ arg_index = None
547+ args = input_source_node .args
548+ assert isinstance (args , tuple )
549+ for i , arg in enumerate (args ):
550+ if arg is input_node :
551+ arg_index = i
552+ break
553+ self .assertIsNotNone (
554+ arg_index ,
555+ f"Input node { input_node } not found in input_source_node.args" ,
448556 )
557+ # Skip comparison if expected qspec is None (e.g., for DerivedQuantizationSpec)
558+ if expected_input_qspecs [arg_index ] is not None :
559+ self .assertEqual (
560+ input_qspec ,
561+ expected_input_qspecs [arg_index ],
562+ f"Input qspec mismatch at arg index { arg_index } " ,
563+ )
449564
450565 def test_all_quantizers_have_annotation_tests (self ) -> None :
451566 """Ensure every CadenceQuantizer subclass is either tested or explicitly excluded."""
0 commit comments