@@ -524,7 +524,6 @@ def partition_types(self) -> List[OpOverload]:
524524
525525
526526class SoftmaxPattern (QuantizationPattern ):
527-
528527 def partition_types (self ) -> List [OpOverload ]:
529528 return [torch .ops .aten ._softmax .default ]
530529
@@ -546,3 +545,57 @@ def get_anchors(
546545
547546 def replacement_op (self ) -> OpOverload :
548547 return torch .ops .cadence .quantized_softmax .default
548+
549+
class MixedW8A32LinearPattern(QuantizationPattern):
    """Pattern for mixed-precision linear: int8 weights with fp32 activations.

    Matches ``aten.linear`` nodes and anchors their weight and bias for
    quantization, leaving the activation input unquantized (anchored under
    ``others``). Replaced by ``cadence.quantized_w8a32_linear``.
    """

    def partition_types(self) -> List[OpOverload]:
        """Return the op overloads this pattern partitions on."""
        return [torch.ops.aten.linear.default]

    def get_anchors(
        self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule]
    ) -> Tuple[PartitionAnchors, fx.Node]:
        """Return quantization anchors for a matched linear node.

        Bails out with empty anchors unless the node is a plain
        ``linear(input, weight, bias)`` call whose input's last dimension is
        a multiple of 4 and whose input is a vector (any leading matrix
        dimension must be 1).
        """
        # pyre-ignore[29]
        linear_layer = fused_partition[0].nodes[-1]

        # Bail if the arguments have different shapes than expected:
        # we require exactly (input, weight, bias) positional args.
        if len(linear_layer.args) != 3 or len(linear_layer.kwargs) > 0:
            return (
                PartitionAnchors(
                    empty=True,
                ),
                linear_layer,
            )

        input_node = linear_layer.args[0]
        # NOTE(review): assumes a shape-propagation pass has populated
        # "tensor_meta" on the input node — confirm upstream pass order.
        input_shape = input_node.meta["tensor_meta"].shape

        # Bail if the input's contraction dim (== weight's in-features)
        # is not a multiple of 4 (SIMD width).
        if input_shape[-1] % 4 != 0:
            return (
                PartitionAnchors(
                    empty=True,
                ),
                linear_layer,
            )
        # Currently only supporting vector-matrix multiplication.
        # Guard on len > 1 so a 1-D (pure vector) input — which is a valid
        # vector-matrix case — doesn't index out of range via shape[-2].
        if len(input_shape) > 1 and input_shape[-2] != 1:
            return (
                PartitionAnchors(
                    empty=True,
                ),
                linear_layer,
            )

        return (
            PartitionAnchors(
                inputs=[],
                weights=[(linear_layer, 1)],
                biases=[(linear_layer, 2)],
                output=[],
                # The fp32 activation input is anchored unquantized.
                others=[(linear_layer, 0)],
            ),
            linear_layer,
        )

    def replacement_op(self) -> OpOverload:
        """Return the fused quantized op that replaces the matched pattern."""
        return torch.ops.cadence.quantized_w8a32_linear.default
0 commit comments