@@ -276,55 +276,20 @@ def get_anchors(
         )


-class Conv1dPattern(QuantizationPattern):
-    def partition_types(self) -> list[OpOverload]:
-        return [torch.ops.aten.conv1d.default]
-
-    def get_anchors(
-        self, gm: fx.GraphModule, fused_partition: list[fx.GraphModule]
-    ) -> PartitionAnchors:
-        # pyre-fixme[29]: `Union[BoundMethod[typing.Callable(torch._C.TensorBase.__ge...
-        conv1d_node = fused_partition[0].nodes[-1]
-
-        bias_qspec = DerivedQuantizationSpec(
-            derived_from=[
-                (conv1d_node.args[0], conv1d_node),
-                (conv1d_node.args[1], conv1d_node),
-            ],
-            derive_qparams_fn=get_bias_qparams,
-            dtype=torch.int32,
-            quant_min=-(2**31),
-            quant_max=2**31 - 1,
-            qscheme=torch.per_tensor_affine,
-        )
-
-        # Keep bias empty if not supplied
-        bias = []
-        if len(conv1d_node.args) > 2 and conv1d_node.args[2] is not None:
-            bias = [(conv1d_node, NodeArgsIdx(2), bias_qspec)]
-
-        return PartitionAnchors(
-            inputs=[(conv1d_node, NodeArgsIdx(0))],
-            weights=[(conv1d_node, NodeArgsIdx(1))],
-            # pyre-fixme[6]: Incompatible parameter type
-            biases=bias,
-            output=[(conv1d_node,)],
-        )
-
-
-class Conv2dPattern(QuantizationPattern):
+class ConvPattern(QuantizationPattern):
+    @abstractmethod
     def partition_types(self) -> list[OpOverload]:
-        return [torch.ops.aten.conv2d.default]
+        pass

     def get_anchors(
         self, gm: fx.GraphModule, fused_partition: list[fx.GraphModule]
     ) -> PartitionAnchors:
-        conv2d_node = fused_partition[0].nodes[-1]
+        conv_node = fused_partition[0].nodes[-1]

         bias_quantization_qspec = DerivedQuantizationSpec(
             derived_from=[
-                (conv2d_node.args[0], conv2d_node),
-                (conv2d_node.args[1], conv2d_node),
+                (conv_node.args[0], conv_node),
+                (conv_node.args[1], conv_node),
             ],
             derive_qparams_fn=get_bias_qparams,
             dtype=torch.int32,
@@ -346,17 +311,27 @@ def get_anchors(

         # Keep bias empty if not supplied
         bias = []
-        if len(conv2d_node.args) > 2 and conv2d_node.args[2] is not None:
-            bias = [(conv2d_node, NodeArgsIdx(2), bias_quantization_qspec)]
+        if len(conv_node.args) > 2 and conv_node.args[2] is not None:
+            bias = [(conv_node, NodeArgsIdx(2), bias_quantization_qspec)]

         return PartitionAnchors(
-            inputs=[(conv2d_node, NodeArgsIdx(0))],
-            weights=[(conv2d_node, NodeArgsIdx(1), weight_quantization_spec)],
+            inputs=[(conv_node, NodeArgsIdx(0))],
+            weights=[(conv_node, NodeArgsIdx(1), weight_quantization_spec)],
             biases=bias,
-            output=[(conv2d_node,)],
+            output=[(conv_node,)],
         )


+class Conv1dPattern(ConvPattern):
+    def partition_types(self) -> list[OpOverload]:
+        return [torch.ops.aten.conv1d.default]
+
+
+class Conv2dPattern(ConvPattern):
+    def partition_types(self) -> list[OpOverload]:
+        return [torch.ops.aten.conv2d.default]
+
+
 class DropoutPattern(SharedSpecPattern):
     """
     Quantizer for Dropout operator.
@@ -432,7 +407,6 @@ def partition_types(self) -> list[OpOverload]:
     def get_anchors(
         self, gm: fx.GraphModule, fused_partition: list[fx.GraphModule]
     ) -> PartitionAnchors:
-        # pyre-fixme[29]: `Union[BoundMethod[typing.Callable(torch._C.TensorBase.__ge...
         linear_node = fused_partition[0].nodes[-1]

         bias_qspec = DerivedQuantizationSpec(
@@ -455,7 +429,6 @@ def get_anchors(
         return PartitionAnchors(
             inputs=[(linear_node, NodeArgsIdx(0))],
             weights=[(linear_node, NodeArgsIdx(1))],
-            # pyre-fixme[6]: Incompatible parameter type
             biases=bias,
             output=[(linear_node,)],
         )
@@ -479,6 +452,23 @@ def partition_types(self):
         return [torch.ops.aten.mean.dim]


+class MmPattern(QuantizationPattern):
+    def partition_types(self) -> list[OpOverload]:
+        return [torch.ops.aten.mm.default]
+
+    def get_anchors(
+        self, gm: fx.GraphModule, fused_partition: list[fx.GraphModule]
+    ) -> PartitionAnchors:
+        mm_node = fused_partition[0].nodes[-1]
+
+        return PartitionAnchors(
+            inputs=[(mm_node, NodeArgsIdx(0))],
+            weights=[(mm_node, NodeArgsIdx(1))],
+            biases=[],
+            output=[(mm_node,)],
+        )
+
+
 class PadPattern(SharedSpecPattern):
     """
     Quantizer for Pad operator.
@@ -552,33 +542,33 @@ def get_anchors(
         )


-class TanhPattern(QuantizationPattern):
+class SigmoidPattern(QuantizationPattern):
     """
-    Quantizer for Tanh operator.
+    Quantizer for Sigmoid operator.

-    The quantization of Tanh output is fixed to scale 1/128, zero point 0, dtype int8.
+    The quantization of Sigmoid output is fixed to scale 1/256, zero point -128, dtype int8.
     """

-    def partition_types(self):
-        return [torch.ops.aten.tanh.default]
+    def partition_types(self) -> list[OpOverload]:
+        return [torch.ops.aten.sigmoid.default]

     def get_anchors(
         self, gm: fx.GraphModule, fused_partition: list[fx.GraphModule]
     ) -> PartitionAnchors:
         return get_anchors_for_fixed_quant_specs(
-            fused_partition, scale=1.0 / 128.0, zero_point=0
+            fused_partition, scale=1.0 / 256.0, zero_point=-128
         )


-class TanhInPlacePattern(QuantizationPattern):
+class TanhPattern(QuantizationPattern):
     """
-    Quantizer for inplace version of Tanh operator (torch.tanh_).
+    Quantizer for Tanh operator.

     The quantization of Tanh output is fixed to scale 1/128, zero point 0, dtype int8.
     """

     def partition_types(self):
-        return [torch.ops.aten.tanh_.default]
+        return [torch.ops.aten.tanh.default]

     def get_anchors(
         self, gm: fx.GraphModule, fused_partition: list[fx.GraphModule]
@@ -588,19 +578,19 @@ def get_anchors(
         )


-class SigmoidPattern(QuantizationPattern):
+class TanhInPlacePattern(QuantizationPattern):
     """
-    Quantizer for Sigmoid operator.
+    Quantizer for inplace version of Tanh operator (torch.tanh_).

-    The quantization of Sigmoid output is fixed to scale 1/256, zero point -128, dtype int8.
+    The quantization of Tanh output is fixed to scale 1/128, zero point 0, dtype int8.
     """

-    def partition_types(self) -> list[OpOverload]:
-        return [torch.ops.aten.sigmoid.default]
+    def partition_types(self):
+        return [torch.ops.aten.tanh_.default]

     def get_anchors(
         self, gm: fx.GraphModule, fused_partition: list[fx.GraphModule]
     ) -> PartitionAnchors:
         return get_anchors_for_fixed_quant_specs(
-            fused_partition, scale=1.0 / 256.0, zero_point=-128
+            fused_partition, scale=1.0 / 128.0, zero_point=0
         )
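
Note on the fixed output specs used by SigmoidPattern, TanhPattern and TanhInPlacePattern above: the chosen scale and zero point map each operator's natural output range exactly onto int8. A minimal standalone sketch (plain PyTorch, not part of this diff) that checks the dequantized ranges:

import torch

# Affine dequantization: value = (q - zero_point) * scale, with q in [-128, 127] for int8.
q = torch.arange(-128, 128, dtype=torch.int32)

# Sigmoid: scale = 1/256, zero_point = -128 -> dequantized values in [0, 255/256],
# covering the sigmoid output range [0, 1).
sigmoid_vals = (q - (-128)).to(torch.float32) * (1.0 / 256.0)
assert sigmoid_vals.min() == 0.0 and sigmoid_vals.max() == 255.0 / 256.0

# Tanh (and in-place tanh_): scale = 1/128, zero_point = 0 -> dequantized values in [-1, 127/128],
# covering the tanh output range (-1, 1).
tanh_vals = q.to(torch.float32) * (1.0 / 128.0)
assert tanh_vals.min() == -1.0 and tanh_vals.max() == 127.0 / 128.0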