@@ -276,55 +276,20 @@ def get_anchors(
         )


-class Conv1dPattern(QuantizationPattern):
-    def partition_types(self) -> list[OpOverload]:
-        return [torch.ops.aten.conv1d.default]
-
-    def get_anchors(
-        self, gm: fx.GraphModule, fused_partition: list[fx.GraphModule]
-    ) -> PartitionAnchors:
-        # pyre-fixme[29]: `Union[BoundMethod[typing.Callable(torch._C.TensorBase.__ge...
-        conv1d_node = fused_partition[0].nodes[-1]
-
-        bias_qspec = DerivedQuantizationSpec(
-            derived_from=[
-                (conv1d_node.args[0], conv1d_node),
-                (conv1d_node.args[1], conv1d_node),
-            ],
-            derive_qparams_fn=get_bias_qparams,
-            dtype=torch.int32,
-            quant_min=-(2**31),
-            quant_max=2**31 - 1,
-            qscheme=torch.per_tensor_affine,
-        )
-
-        # Keep bias empty if not supplied
-        bias = []
-        if len(conv1d_node.args) > 2 and conv1d_node.args[2] is not None:
-            bias = [(conv1d_node, NodeArgsIdx(2), bias_qspec)]
-
-        return PartitionAnchors(
-            inputs=[(conv1d_node, NodeArgsIdx(0))],
-            weights=[(conv1d_node, NodeArgsIdx(1))],
-            # pyre-fixme[6]: Incompatible parameter type
-            biases=bias,
-            output=[(conv1d_node,)],
-        )
-
-
-class Conv2dPattern(QuantizationPattern):
+class ConvPattern(QuantizationPattern):
+    @abstractmethod
     def partition_types(self) -> list[OpOverload]:
-        return [torch.ops.aten.conv2d.default]
+        pass

     def get_anchors(
         self, gm: fx.GraphModule, fused_partition: list[fx.GraphModule]
     ) -> PartitionAnchors:
-        conv2d_node = fused_partition[0].nodes[-1]
+        conv_node = fused_partition[0].nodes[-1]

         bias_quantization_qspec = DerivedQuantizationSpec(
             derived_from=[
-                (conv2d_node.args[0], conv2d_node),
-                (conv2d_node.args[1], conv2d_node),
+                (conv_node.args[0], conv_node),
+                (conv_node.args[1], conv_node),
             ],
             derive_qparams_fn=get_bias_qparams,
             dtype=torch.int32,
@@ -346,17 +311,27 @@ def get_anchors(

         # Keep bias empty if not supplied
         bias = []
-        if len(conv2d_node.args) > 2 and conv2d_node.args[2] is not None:
-            bias = [(conv2d_node, NodeArgsIdx(2), bias_quantization_qspec)]
+        if len(conv_node.args) > 2 and conv_node.args[2] is not None:
+            bias = [(conv_node, NodeArgsIdx(2), bias_quantization_qspec)]

         return PartitionAnchors(
-            inputs=[(conv2d_node, NodeArgsIdx(0))],
-            weights=[(conv2d_node, NodeArgsIdx(1), weight_quantization_spec)],
+            inputs=[(conv_node, NodeArgsIdx(0))],
+            weights=[(conv_node, NodeArgsIdx(1), weight_quantization_spec)],
             biases=bias,
-            output=[(conv2d_node,)],
+            output=[(conv_node,)],
         )


+class Conv1dPattern(ConvPattern):
+    def partition_types(self) -> list[OpOverload]:
+        return [torch.ops.aten.conv1d.default]
+
+
+class Conv2dPattern(ConvPattern):
+    def partition_types(self) -> list[OpOverload]:
+        return [torch.ops.aten.conv2d.default]
+
+
 class DropoutPattern(SharedSpecPattern):
     """
     Quantizer for Dropout operator.
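
Note on the refactor above: `Conv1dPattern.get_anchors` and `Conv2dPattern.get_anchors` were near-duplicates, so the shared body moves into an abstract `ConvPattern` and the subclasses keep only `partition_types`. The bias anchor's `DerivedQuantizationSpec` derives its qparams from the already-observed input and weight via `get_bias_qparams`, which is defined outside this diff. A minimal sketch of the usual convention (bias scale = input scale × weight scale, zero point 0, so the int32 bias sums directly into the conv accumulator), assuming the derived-from observers expose `calculate_qparams()`:

    import torch

    def get_bias_qparams_sketch(obs_or_fqs):
        # obs_or_fqs carries the observers of the two `derived_from` edges:
        # [0] the conv input activation, [1] the weight.
        act_scale, _ = obs_or_fqs[0].calculate_qparams()
        weight_scale, _ = obs_or_fqs[1].calculate_qparams()
        bias_scale = act_scale * weight_scale
        return bias_scale, torch.zeros_like(bias_scale, dtype=torch.int64)
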
@@ -432,7 +407,6 @@ def partition_types(self) -> list[OpOverload]:
     def get_anchors(
         self, gm: fx.GraphModule, fused_partition: list[fx.GraphModule]
     ) -> PartitionAnchors:
-        # pyre-fixme[29]: `Union[BoundMethod[typing.Callable(torch._C.TensorBase.__ge...
         linear_node = fused_partition[0].nodes[-1]

         bias_qspec = DerivedQuantizationSpec(
@@ -455,7 +429,6 @@ def get_anchors(
         return PartitionAnchors(
             inputs=[(linear_node, NodeArgsIdx(0))],
             weights=[(linear_node, NodeArgsIdx(1))],
-            # pyre-fixme[6]: Incompatible parameter type
             biases=bias,
             output=[(linear_node,)],
         )
@@ -479,6 +452,23 @@ def partition_types(self):
         return [torch.ops.aten.mean.dim]


+class MmPattern(QuantizationPattern):
+    def partition_types(self) -> list[OpOverload]:
+        return [torch.ops.aten.mm.default]
+
+    def get_anchors(
+        self, gm: fx.GraphModule, fused_partition: list[fx.GraphModule]
+    ) -> PartitionAnchors:
+        mm_node = fused_partition[0].nodes[-1]
+
+        return PartitionAnchors(
+            inputs=[(mm_node, NodeArgsIdx(0))],
+            weights=[(mm_node, NodeArgsIdx(1))],
+            biases=[],
+            output=[(mm_node,)],
+        )
+
+
 class PadPattern(SharedSpecPattern):
     """
     Quantizer for Pad operator.
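
`aten.mm.default` is a plain matrix multiply with exactly two tensor arguments and no bias, which is why the new `MmPattern` anchors argument 0 as the input, argument 1 as the weight, and passes `biases=[]`. For illustration only (not part of the diff), a module that would produce this node after `torch.export`:

    import torch

    class MatMul(torch.nn.Module):
        def forward(self, x, w):
            # torch.mm shows up as aten.mm.default in the exported graph
            return torch.mm(x, w)
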
@@ -552,33 +542,33 @@ def get_anchors(
         )


-class TanhPattern(QuantizationPattern):
+class SigmoidPattern(QuantizationPattern):
     """
-    Quantizer for Tanh operator.
+    Quantizer for Sigmoid operator.

-    The quantization of Tanh output is fixed to scale 1/128, zero point 0, dtype int8.
+    The quantization of Sigmoid output is fixed to scale 1/256, zero point -128, dtype int8.
     """

-    def partition_types(self):
-        return [torch.ops.aten.tanh.default]
+    def partition_types(self) -> list[OpOverload]:
+        return [torch.ops.aten.sigmoid.default]

     def get_anchors(
         self, gm: fx.GraphModule, fused_partition: list[fx.GraphModule]
     ) -> PartitionAnchors:
         return get_anchors_for_fixed_quant_specs(
-            fused_partition, scale=1.0 / 128.0, zero_point=0
+            fused_partition, scale=1.0 / 256.0, zero_point=-128
         )


-class TanhInPlacePattern(QuantizationPattern):
+class TanhPattern(QuantizationPattern):
     """
-    Quantizer for inplace version of Tanh operator (torch.tanh_).
+    Quantizer for Tanh operator.

     The quantization of Tanh output is fixed to scale 1/128, zero point 0, dtype int8.
     """

     def partition_types(self):
-        return [torch.ops.aten.tanh_.default]
+        return [torch.ops.aten.tanh.default]

     def get_anchors(
         self, gm: fx.GraphModule, fused_partition: list[fx.GraphModule]
@@ -588,19 +578,19 @@ def get_anchors(
         )


-class SigmoidPattern(QuantizationPattern):
+class TanhInPlacePattern(QuantizationPattern):
     """
-    Quantizer for Sigmoid operator.
+    Quantizer for inplace version of Tanh operator (torch.tanh_).

-    The quantization of Sigmoid output is fixed to scale 1/256, zero point -128, dtype int8.
+    The quantization of Tanh output is fixed to scale 1/128, zero point 0, dtype int8.
     """

-    def partition_types(self) -> list[OpOverload]:
-        return [torch.ops.aten.sigmoid.default]
+    def partition_types(self):
+        return [torch.ops.aten.tanh_.default]

     def get_anchors(
         self, gm: fx.GraphModule, fused_partition: list[fx.GraphModule]
     ) -> PartitionAnchors:
         return get_anchors_for_fixed_quant_specs(
-            fused_partition, scale=1.0 / 256.0, zero_point=-128
+            fused_partition, scale=1.0 / 128.0, zero_point=0
         )
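
The last two hunks only reorder `SigmoidPattern`, `TanhPattern`, and `TanhInPlacePattern`; each class keeps its fixed output qparams. The constants match the operators' output ranges: tanh spans [-1, 1], so scale 1/128 with zero point 0 fills int8's [-128, 127] (with 1.0 clamping to 127); sigmoid spans [0, 1], so scale 1/256 with zero point -128 does the same. A small sketch of that arithmetic, independent of the quantizer code:

    import torch

    def quantize_fixed(x, scale, zero_point):
        # Affine int8 quantization with fixed qparams.
        q = torch.round(x / scale) + zero_point
        return torch.clamp(q, -128, 127).to(torch.int8)

    x = torch.linspace(-4.0, 4.0, 9)
    q_tanh = quantize_fixed(torch.tanh(x), scale=1.0 / 128.0, zero_point=0)
    q_sigmoid = quantize_fixed(torch.sigmoid(x), scale=1.0 / 256.0, zero_point=-128)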