@@ -276,55 +276,20 @@ def get_anchors(
         )


-class Conv1dPattern(QuantizationPattern):
-    def partition_types(self) -> list[OpOverload]:
-        return [torch.ops.aten.conv1d.default]
-
-    def get_anchors(
-        self, gm: fx.GraphModule, fused_partition: list[fx.GraphModule]
-    ) -> PartitionAnchors:
-        # pyre-fixme[29]: `Union[BoundMethod[typing.Callable(torch._C.TensorBase.__ge...
-        conv1d_node = fused_partition[0].nodes[-1]
-
-        bias_qspec = DerivedQuantizationSpec(
-            derived_from=[
-                (conv1d_node.args[0], conv1d_node),
-                (conv1d_node.args[1], conv1d_node),
-            ],
-            derive_qparams_fn=get_bias_qparams,
-            dtype=torch.int32,
-            quant_min=-(2**31),
-            quant_max=2**31 - 1,
-            qscheme=torch.per_tensor_affine,
-        )
-
-        # Keep bias empty if not supplied
-        bias = []
-        if len(conv1d_node.args) > 2 and conv1d_node.args[2] is not None:
-            bias = [(conv1d_node, NodeArgsIdx(2), bias_qspec)]
-
-        return PartitionAnchors(
-            inputs=[(conv1d_node, NodeArgsIdx(0))],
-            weights=[(conv1d_node, NodeArgsIdx(1))],
-            # pyre-fixme[6]: Incompatible parameter type
-            biases=bias,
-            output=[(conv1d_node,)],
-        )
-
-
-class Conv2dPattern(QuantizationPattern):
+class ConvPattern(QuantizationPattern):
+    @abstractmethod
     def partition_types(self) -> list[OpOverload]:
-        return [torch.ops.aten.conv2d.default]
+        pass

     def get_anchors(
         self, gm: fx.GraphModule, fused_partition: list[fx.GraphModule]
     ) -> PartitionAnchors:
-        conv2d_node = fused_partition[0].nodes[-1]
+        conv_node = fused_partition[0].nodes[-1]

         bias_quantization_qspec = DerivedQuantizationSpec(
             derived_from=[
-                (conv2d_node.args[0], conv2d_node),
-                (conv2d_node.args[1], conv2d_node),
+                (conv_node.args[0], conv_node),
+                (conv_node.args[1], conv_node),
             ],
             derive_qparams_fn=get_bias_qparams,
             dtype=torch.int32,
@@ -346,17 +311,27 @@ def get_anchors(

         # Keep bias empty if not supplied
         bias = []
-        if len(conv2d_node.args) > 2 and conv2d_node.args[2] is not None:
-            bias = [(conv2d_node, NodeArgsIdx(2), bias_quantization_qspec)]
+        if len(conv_node.args) > 2 and conv_node.args[2] is not None:
+            bias = [(conv_node, NodeArgsIdx(2), bias_quantization_qspec)]

         return PartitionAnchors(
-            inputs=[(conv2d_node, NodeArgsIdx(0))],
-            weights=[(conv2d_node, NodeArgsIdx(1), weight_quantization_spec)],
+            inputs=[(conv_node, NodeArgsIdx(0))],
+            weights=[(conv_node, NodeArgsIdx(1), weight_quantization_spec)],
             biases=bias,
-            output=[(conv2d_node,)],
+            output=[(conv_node,)],
         )


+class Conv1dPattern(ConvPattern):
+    def partition_types(self) -> list[OpOverload]:
+        return [torch.ops.aten.conv1d.default]
+
+
+class Conv2dPattern(ConvPattern):
+    def partition_types(self) -> list[OpOverload]:
+        return [torch.ops.aten.conv2d.default]
+
+
 class DropoutPattern(SharedSpecPattern):
     """
     Quantizer for Dropout operator.
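
With the shared ConvPattern base class, Conv1dPattern and Conv2dPattern collapse into two-line subclasses and the anchor logic lives in exactly one place. As an illustration of the extension point (hypothetical, not part of this commit), a further conv variant would only need to name its op:

class Conv3dPattern(ConvPattern):  # hypothetical example, not in this commit
    def partition_types(self) -> list[OpOverload]:
        return [torch.ops.aten.conv3d.default]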
@@ -432,7 +407,6 @@ def partition_types(self) -> list[OpOverload]:
     def get_anchors(
         self, gm: fx.GraphModule, fused_partition: list[fx.GraphModule]
     ) -> PartitionAnchors:
-        # pyre-fixme[29]: `Union[BoundMethod[typing.Callable(torch._C.TensorBase.__ge...
         linear_node = fused_partition[0].nodes[-1]

         bias_qspec = DerivedQuantizationSpec(
@@ -455,7 +429,6 @@ def get_anchors(
         return PartitionAnchors(
             inputs=[(linear_node, NodeArgsIdx(0))],
             weights=[(linear_node, NodeArgsIdx(1))],
-            # pyre-fixme[6]: Incompatible parameter type
             biases=bias,
             output=[(linear_node,)],
         )
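
In both the conv and linear anchors, the int32 bias spec is a DerivedQuantizationSpec whose qparams come from get_bias_qparams applied to the input and weight observers. That function's body is outside this diff; a minimal sketch, assuming it follows the usual convention (bias scale = input scale × weight scale, zero point 0, as in PyTorch's XNNPACK quantizer), would be:

# Sketch only: the real get_bias_qparams is not shown in this diff.
def get_bias_qparams(obs_or_fqs):  # observers/fake-quants for (input, weight)
    act_scale, _ = obs_or_fqs[0].calculate_qparams()
    weight_scale, _ = obs_or_fqs[1].calculate_qparams()
    # Deriving bias_scale = act_scale * weight_scale lets the backend
    # requantize (x * w + b) with a single output multiplier.
    bias_scale = act_scale * weight_scale
    bias_zero_point = torch.zeros_like(bias_scale, dtype=torch.int32)
    return bias_scale, bias_zero_point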
@@ -479,6 +452,23 @@ def partition_types(self):
         return [torch.ops.aten.mean.dim]


+class MmPattern(QuantizationPattern):
+    def partition_types(self) -> list[OpOverload]:
+        return [torch.ops.aten.mm.default]
+
+    def get_anchors(
+        self, gm: fx.GraphModule, fused_partition: list[fx.GraphModule]
+    ) -> PartitionAnchors:
+        mm_node = fused_partition[0].nodes[-1]
+
+        return PartitionAnchors(
+            inputs=[(mm_node, NodeArgsIdx(0))],
+            weights=[(mm_node, NodeArgsIdx(1))],
+            biases=[],
+            output=[(mm_node,)],
+        )
+
+
 class PadPattern(SharedSpecPattern):
     """
     Quantizer for Pad operator.
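
aten.mm is the bare two-dimensional matrix multiply, so there is no bias argument to anchor (biases=[]) and the second operand is treated as the weight. For reference, a module like the following (an illustrative example, not from this commit) exports to a graph containing torch.ops.aten.mm.default, which this pattern would match:

import torch

class TinyMm(torch.nn.Module):  # illustrative only
    def __init__(self):
        super().__init__()
        self.w = torch.nn.Parameter(torch.randn(16, 8))

    def forward(self, x):
        return torch.mm(x, self.w)

ep = torch.export.export(TinyMm(), (torch.randn(4, 16),))
print(ep.graph)  # contains a call to torch.ops.aten.mm.default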
@@ -552,33 +542,33 @@ def get_anchors(
         )


-class TanhPattern(QuantizationPattern):
+class SigmoidPattern(QuantizationPattern):
     """
-    Quantizer for Tanh operator.
+    Quantizer for Sigmoid operator.

-    The quantization of Tanh output is fixed to scale 1/128, zero point 0, dtype int8.
+    The quantization of Sigmoid output is fixed to scale 1/256, zero point -128, dtype int8.
     """

-    def partition_types(self):
-        return [torch.ops.aten.tanh.default]
+    def partition_types(self) -> list[OpOverload]:
+        return [torch.ops.aten.sigmoid.default]

     def get_anchors(
         self, gm: fx.GraphModule, fused_partition: list[fx.GraphModule]
     ) -> PartitionAnchors:
         return get_anchors_for_fixed_quant_specs(
-            fused_partition, scale=1.0 / 128.0, zero_point=0
+            fused_partition, scale=1.0 / 256.0, zero_point=-128
         )


-class TanhInPlacePattern(QuantizationPattern):
+class TanhPattern(QuantizationPattern):
     """
-    Quantizer for inplace version of Tanh operator (torch.tanh_).
+    Quantizer for Tanh operator.

     The quantization of Tanh output is fixed to scale 1/128, zero point 0, dtype int8.
     """

     def partition_types(self):
-        return [torch.ops.aten.tanh_.default]
+        return [torch.ops.aten.tanh.default]

     def get_anchors(
         self, gm: fx.GraphModule, fused_partition: list[fx.GraphModule]
@@ -588,19 +578,19 @@ def get_anchors(
         )


-class SigmoidPattern(QuantizationPattern):
+class TanhInPlacePattern(QuantizationPattern):
     """
-    Quantizer for Sigmoid operator.
+    Quantizer for inplace version of Tanh operator (torch.tanh_).

-    The quantization of Sigmoid output is fixed to scale 1/256, zero point -128, dtype int8.
+    The quantization of Tanh output is fixed to scale 1/128, zero point 0, dtype int8.
     """

-    def partition_types(self) -> list[OpOverload]:
-        return [torch.ops.aten.sigmoid.default]
+    def partition_types(self):
+        return [torch.ops.aten.tanh_.default]

     def get_anchors(
         self, gm: fx.GraphModule, fused_partition: list[fx.GraphModule]
     ) -> PartitionAnchors:
         return get_anchors_for_fixed_quant_specs(
-            fused_partition, scale=1.0 / 256.0, zero_point=-128
+            fused_partition, scale=1.0 / 128.0, zero_point=0
         )
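
After the swap, the fixed constants line up with each activation's codomain: sigmoid outputs lie in (0, 1), which int8 with scale 1/256 and zero point -128 covers as [0, 255/256], while tanh outputs lie in (-1, 1), covered by scale 1/128 and zero point 0 as [-1, 127/128]. A quick sanity check of both ranges (dequantized value = (q - zero_point) * scale):

import torch

q = torch.arange(-128, 128, dtype=torch.float32)  # all int8 codes
sigmoid_range = (q - (-128)) * (1.0 / 256.0)
tanh_range = (q - 0) * (1.0 / 128.0)
print(sigmoid_range.min().item(), sigmoid_range.max().item())  # 0.0 0.99609375
print(tanh_range.min().item(), tanh_range.max().item())        # -1.0 0.9921875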