@@ -276,60 +276,25 @@ def get_anchors(
         )
 
 
-class Conv1dPattern(QuantizationPattern):
-    def partition_types(self) -> list[OpOverload]:
-        return [torch.ops.aten.conv1d.default]
-
-    def get_anchors(
-        self, gm: fx.GraphModule, fused_partition: list[fx.GraphModule]
-    ) -> PartitionAnchors:
-        # pyre-fixme[29]: `Union[BoundMethod[typing.Callable(torch._C.TensorBase.__ge...
-        conv1d_node = fused_partition[0].nodes[-1]
-
-        bias_qspec = DerivedQuantizationSpec(
-            derived_from=[
-                (conv1d_node.args[0], conv1d_node),
-                (conv1d_node.args[1], conv1d_node),
-            ],
-            derive_qparams_fn=get_bias_qparams,
-            dtype=torch.int32,
-            quant_min=-(2**31),
-            quant_max=2**31 - 1,
-            qscheme=torch.per_tensor_affine,
-        )
-
-        # Keep bias empty if not supplied
-        bias = []
-        if len(conv1d_node.args) > 2 and conv1d_node.args[2] is not None:
-            bias = [(conv1d_node, NodeArgsIdx(2), bias_qspec)]
-
-        return PartitionAnchors(
-            inputs=[(conv1d_node, NodeArgsIdx(0))],
-            weights=[(conv1d_node, NodeArgsIdx(1))],
-            # pyre-fixme[6]: Incompatible parameter type
-            biases=bias,
-            output=[(conv1d_node,)],
-        )
-
-
-class Conv2dPattern(QuantizationPattern):
+class ConvPattern(QuantizationPattern):
+    @abstractmethod
     def partition_types(self) -> list[OpOverload]:
-        return [torch.ops.aten.conv2d.default]
+        pass
 
     def get_anchors(
-        self, gm: fx.GraphModule, fused_partition: list[fx.GraphModule]
+        self, gm: fx.GraphModule, fused_partition: list[fx.GraphModule]
     ) -> PartitionAnchors:
-        conv2d_node = fused_partition[0].nodes[-1]
+        conv_node = fused_partition[0].nodes[-1]
 
         bias_quantization_qspec = DerivedQuantizationSpec(
             derived_from=[
-                (conv2d_node.args[0], conv2d_node),
-                (conv2d_node.args[1], conv2d_node),
+                (conv_node.args[0], conv_node),
+                (conv_node.args[1], conv_node),
             ],
             derive_qparams_fn=get_bias_qparams,
             dtype=torch.int32,
-            quant_min=-(2**31) + 1,
-            quant_max=2**31 - 1,
+            quant_min=-(2**31) + 1,
+            quant_max=2**31 - 1,
             qscheme=torch.per_channel_symmetric,
             ch_axis=0,
         )
@@ -346,16 +311,25 @@ def get_anchors(
 
         # Keep bias empty if not supplied
         bias = []
-        if len(conv2d_node.args) > 2 and conv2d_node.args[2] is not None:
-            bias = [(conv2d_node, NodeArgsIdx(2), bias_quantization_qspec)]
+        if len(conv_node.args) > 2 and conv_node.args[2] is not None:
+            bias = [(conv_node, NodeArgsIdx(2), bias_quantization_qspec)]
 
         return PartitionAnchors(
-            inputs=[(conv2d_node, NodeArgsIdx(0))],
-            weights=[(conv2d_node, NodeArgsIdx(1), weight_quantization_spec)],
+            inputs=[(conv_node, NodeArgsIdx(0))],
+            weights=[(conv_node, NodeArgsIdx(1), weight_quantization_spec)],
             biases=bias,
-            output=[(conv2d_node,)],
+            output=[(conv_node,)],
         )
 
+class Conv1dPattern(ConvPattern):
+    def partition_types(self) -> list[OpOverload]:
+        return [torch.ops.aten.conv1d.default]
+
+
+class Conv2dPattern(ConvPattern):
+    def partition_types(self) -> list[OpOverload]:
+        return [torch.ops.aten.conv2d.default]
+
 
 class DropoutPattern(SharedSpecPattern):
     """
@@ -432,7 +406,6 @@ def partition_types(self) -> list[OpOverload]:
     def get_anchors(
         self, gm: fx.GraphModule, fused_partition: list[fx.GraphModule]
     ) -> PartitionAnchors:
-        # pyre-fixme[29]: `Union[BoundMethod[typing.Callable(torch._C.TensorBase.__ge...
         linear_node = fused_partition[0].nodes[-1]
 
         bias_qspec = DerivedQuantizationSpec(
@@ -455,7 +428,6 @@ def get_anchors(
         return PartitionAnchors(
             inputs=[(linear_node, NodeArgsIdx(0))],
             weights=[(linear_node, NodeArgsIdx(1))],
-            # pyre-fixme[6]: Incompatible parameter type
             biases=bias,
             output=[(linear_node,)],
         )
@@ -479,6 +451,23 @@ def partition_types(self):
         return [torch.ops.aten.mean.dim]
 
 
+class MmPattern(QuantizationPattern):
+    def partition_types(self) -> list[OpOverload]:
+        return [torch.ops.aten.mm.default]
+
+    def get_anchors(
+        self, gm: fx.GraphModule, fused_partition: list[fx.GraphModule]
+    ) -> PartitionAnchors:
+        mm_node = fused_partition[0].nodes[-1]
+
+        return PartitionAnchors(
+            inputs=[(mm_node, NodeArgsIdx(0))],
+            weights=[(mm_node, NodeArgsIdx(1))],
+            biases=[],
+            output=[(mm_node,)],
+        )
+
+
 class PadPattern(SharedSpecPattern):
     """
     Quantizer for Pad operator.
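For context on the new MmPattern: aten.mm is the bias-free 2-D matrix multiply, so the pattern anchors only an input (args[0]) and a weight (args[1]) and passes biases=[]. A minimal sketch (hypothetical module, standard torch.export API) of a graph this pattern would match:

import torch

class TinyMm(torch.nn.Module):  # hypothetical example module
    def __init__(self):
        super().__init__()
        self.w = torch.nn.Parameter(torch.randn(8, 4))

    def forward(self, x):
        return torch.mm(x, self.w)  # exports as torch.ops.aten.mm.default

ep = torch.export.export(TinyMm(), (torch.randn(2, 8),))
print(ep.graph)  # the graph should contain an aten.mm.default node, with no bias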
@@ -552,33 +541,33 @@ def get_anchors(
         )
 
 
-class TanhPattern(QuantizationPattern):
+class SigmoidPattern(QuantizationPattern):
     """
-    Quantizer for Tanh operator.
+    Quantizer for Sigmoid operator.
 
-    The quantization of Tanh output is fixed to scale 1/128, zero point 0, dtype int8.
+    The quantization of Sigmoid output is fixed to scale 1/256, zero point -128, dtype int8.
     """
 
-    def partition_types(self):
-        return [torch.ops.aten.tanh.default]
+    def partition_types(self) -> list[OpOverload]:
+        return [torch.ops.aten.sigmoid.default]
 
     def get_anchors(
-        self, gm: fx.GraphModule, fused_partition: list[fx.GraphModule]
+        self, gm: fx.GraphModule, fused_partition: list[fx.GraphModule]
     ) -> PartitionAnchors:
         return get_anchors_for_fixed_quant_specs(
-            fused_partition, scale=1.0 / 128.0, zero_point=0
+            fused_partition, scale=1.0 / 256.0, zero_point=-128
         )
 
 
-class TanhInPlacePattern(QuantizationPattern):
+class TanhPattern(QuantizationPattern):
     """
-    Quantizer for inplace version of Tanh operator (torch.tanh_).
+    Quantizer for Tanh operator.
 
     The quantization of Tanh output is fixed to scale 1/128, zero point 0, dtype int8.
     """
 
     def partition_types(self):
-        return [torch.ops.aten.tanh_.default]
+        return [torch.ops.aten.tanh.default]
 
     def get_anchors(
         self, gm: fx.GraphModule, fused_partition: list[fx.GraphModule]
@@ -588,19 +577,21 @@ def get_anchors(
         )
 
 
-class SigmoidPattern(QuantizationPattern):
+class TanhInPlacePattern(QuantizationPattern):
     """
-    Quantizer for Sigmoid operator.
+    Quantizer for inplace version of Tanh operator (torch.tanh_).
 
-    The quantization of Sigmoid output is fixed to scale 1/256, zero point -128, dtype int8.
+    The quantization of Tanh output is fixed to scale 1/128, zero point 0, dtype int8.
     """
 
-    def partition_types(self) -> list[OpOverload]:
-        return [torch.ops.aten.sigmoid.default]
+    def partition_types(self):
+        return [torch.ops.aten.tanh_.default]
 
     def get_anchors(
         self, gm: fx.GraphModule, fused_partition: list[fx.GraphModule]
     ) -> PartitionAnchors:
         return get_anchors_for_fixed_quant_specs(
-            fused_partition, scale=1.0 / 256.0, zero_point=-128
+            fused_partition, scale=1.0 / 128.0, zero_point=0
         )
+
+
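A quick sanity check on the fixed output qparams above, using the standard int8 affine dequantization value = (q - zero_point) * scale for q in [-128, 127]: the Sigmoid spec (scale 1/256, zero point -128) maps the int8 grid onto [0, 255/256], matching sigmoid's (0, 1) output range, and the Tanh spec (scale 1/128, zero point 0) maps it onto [-1, 127/128], matching tanh's (-1, 1).

# Plain-Python sketch verifying the ranges implied by the fixed qparams.
def dequant_range(scale: float, zero_point: int) -> tuple[float, float]:
    # Smallest and largest representable values under int8 affine quantization.
    return ((-128 - zero_point) * scale, (127 - zero_point) * scale)

print(dequant_range(1.0 / 256.0, -128))  # (0.0, 0.99609375)  -> Sigmoid
print(dequant_range(1.0 / 128.0, 0))     # (-1.0, 0.9921875)  -> Tanh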