@@ -524,7 +524,6 @@ def partition_types(self) -> List[OpOverload]:
 
 
 class SoftmaxPattern(QuantizationPattern):
-
     def partition_types(self) -> List[OpOverload]:
         return [torch.ops.aten._softmax.default]
 
@@ -546,3 +545,57 @@ def get_anchors(
 
     def replacement_op(self) -> OpOverload:
         return torch.ops.cadence.quantized_softmax.default
+
+
+class MixedW8A32LinearPattern(QuantizationPattern):
+    def partition_types(self) -> List[OpOverload]:
+        return [torch.ops.aten.linear.default]
+
+    def get_anchors(
+        self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule]
+    ) -> Tuple[PartitionAnchors, fx.Node]:
+        # pyre-ignore[29]
+        linear_layer = fused_partition[0].nodes[-1]
+
+        # Bail unless the node is called as linear(input, weight, bias) with no kwargs
+        if len(linear_layer.args) != 3 or len(linear_layer.kwargs) > 0:
+            return (
+                PartitionAnchors(
+                    empty=True,
+                ),
+                linear_layer,
+            )
+
+        input_node = linear_layer.args[0]
+        input_shape = input_node.meta["tensor_meta"].shape
+
+        # Bail if the input feature count is not a multiple of 4 (SIMD width)
+        if input_shape[-1] % 4 != 0:
+            return (
+                PartitionAnchors(
+                    empty=True,
+                ),
+                linear_layer,
+            )
+        # Currently only supporting vector-matrix multiplication
+        if len(input_shape) > 1 and input_shape[-2] != 1:
+            return (
+                PartitionAnchors(
+                    empty=True,
+                ),
+                linear_layer,
+            )
+
+        return (
+            PartitionAnchors(
+                inputs=[],
+                weights=[(linear_layer, 1)],
+                biases=[(linear_layer, 2)],
+                output=[],
+                others=[(linear_layer, 0)],
+            ),
+            linear_layer,
+        )
+
+    def replacement_op(self) -> OpOverload:
+        return torch.ops.cadence.quantized_w8a32_linear.default
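A minimal, self-contained sketch (not part of the commit) of the shape guards that MixedW8A32LinearPattern.get_anchors applies before anchoring a node; the helper name eligible_for_w8a32 is hypothetical:

    import torch

    def eligible_for_w8a32(input_shape: torch.Size) -> bool:
        # The SIMD kernel requires the input feature count (the last
        # dimension, which equals the weight's in_features) to be a
        # multiple of 4.
        if input_shape[-1] % 4 != 0:
            return False
        # Only vector-matrix products are supported: the dimension just
        # before the features, if any, must have size 1.
        if len(input_shape) > 1 and input_shape[-2] != 1:
            return False
        return True

    print(eligible_for_w8a32(torch.Size([1, 8])))  # True: 1x8 row vector
    print(eligible_for_w8a32(torch.Size([2, 8])))  # False: matrix-matrix product
    print(eligible_for_w8a32(torch.Size([1, 6])))  # False: 6 is not a multiple of 4

For nodes that pass, the anchors map onto aten.linear's (input, weight, bias) args: the weight (arg 1) and bias (arg 2) are anchored for quantization, while the input (arg 0) is routed through `others`, consistent with the mixed w8a32 (8-bit weights, 32-bit activations) naming of the replacement op.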