@@ -556,6 +556,7 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
556
556
patterns.onOp (
557
557
" QLinearMatMul" , 1 ,
558
558
[](OpBinder binder, ConversionPatternRewriter &rewriter) {
559
+ Location loc = binder.getLoc ();
559
560
Torch::ValueTensorType resultType;
560
561
llvm::SmallVector<Value> operands;
561
562
if (binder.tensorOperands (operands, 8 ) ||
@@ -577,10 +578,10 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
577
578
return false ;
578
579
return true ;
579
580
};
580
- if (!check (aScale) || !check (aZp) || !check (bScale) || !check (bZp) ||
581
- !check (cScale) || !check (cScale))
581
+ if (!check (aScale) || !check (aZp) || !check (cScale) || !check (cZp))
582
582
return rewriter.notifyMatchFailure (
583
- binder.op , " not supported for non per-tensor quantization" );
583
+ binder.op , " input `a` and output not supported for non "
584
+ " per-tensor quantization" );
584
585
585
586
Value emptyList = rewriter.create <Torch::PrimListConstructOp>(
586
587
binder.getLoc (),
@@ -605,26 +606,117 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
605
606
};
606
607
607
608
aZp = extract (aZp);
608
- bZp = extract (bZp);
609
609
cZp = extract (cZp);
610
610
aScale = extract (aScale);
611
- bScale = extract (bScale);
612
611
cScale = extract (cScale);
613
612
614
- auto make = [&rewriter, &binder](Value v, Value scale,
615
- Value zp) -> Value {
613
+ auto makePerTensor = [&rewriter, &binder](Value v, Value scale,
614
+ Value zp) -> Value {
616
615
auto ty = cast<Torch::ValueTensorType>(v.getType ());
617
616
auto newTy = getQTorchTypeFromTorchIntType (ty);
618
617
return rewriter.create <Torch::Aten_MakePerTensorQuantizedTensorOp>(
619
618
binder.getLoc (), newTy, v, scale, zp);
620
619
};
621
620
622
- a = make (a, aScale, aZp);
623
- b = make (b, bScale, bZp);
621
+ // The onnx's QLinearMatMul op allows per-column (per-channel)
622
+ // quantization only for the "b" tensor.
623
+ bool isPerColumnQuantization = false ;
624
+ auto bTy = dyn_cast<Torch::ValueTensorType>(b.getType ());
625
+ auto bScaleTy = dyn_cast<Torch::ValueTensorType>(bScale.getType ());
626
+ auto bZpTy = dyn_cast<Torch::ValueTensorType>(bZp.getType ());
627
+ if (!bTy || !bScaleTy || !bZpTy || !bTy.hasSizes () ||
628
+ !bScaleTy.hasSizes () || !bZpTy.hasSizes ())
629
+ return rewriter.notifyMatchFailure (
630
+ binder.op , " Expected b, b_scale, and b_zero_point "
631
+ " arguments to have sizes" );
632
+ ArrayRef<int64_t > bShape (bTy.getSizes ());
633
+ SmallVector<int64_t > bScaleShape (bScaleTy.getSizes ());
634
+ SmallVector<int64_t > bZpShape (bZpTy.getSizes ());
635
+ if (bScaleShape.size () == 0 ||
636
+ llvm::all_of (bScaleShape, [](int64_t s) { return s == 1 ; })) {
637
+ bZp = extract (bZp);
638
+ bScale = extract (bScale);
639
+ b = makePerTensor (b, bScale, bZp);
640
+ } else if ((bScaleShape.size () == 1 ||
641
+ bScaleShape.size () == bShape.size ()) &&
642
+ bScaleShape.back () != Torch::kUnknownSize &&
643
+ bScaleShape.back () == bShape.back ()) {
644
+ // Since the `QuantizedMatmulOp` in the downstream pipeline
645
+ // ("Linalg") does not support the per-column (per-channel)
646
+ // quantization for the arg `b`, hence for this particular case we
647
+ // perform the matmul over the dequantized inputs i.e., `a` and `b`
648
+ // instead of relying on the downstream pipeline to handle this. This
649
+ // code can be removed and made similar to the other paths in this
650
+ // lowering once the per-column (per-channel) quantization support is
651
+ // added in the downstream pipeline.
652
+ isPerColumnQuantization = true ;
653
+
654
+ auto aTy = dyn_cast<Torch::ValueTensorType>(a.getType ());
655
+ if (!aTy || !aTy.hasSizes ())
656
+ return rewriter.notifyMatchFailure (
657
+ binder.op , " Expected input argument `a` to have sizes" );
658
+
659
+ // Dequantizing the a
660
+ // a = a.to(dtype=torch.float32)
661
+ // a_dequant = (a - a_zero_point) * a_scale
662
+
663
+ // Converting the a tensor to float32 type.
664
+ Value none = rewriter.create <Torch::ConstantNoneOp>(loc);
665
+ Value cstFalse = rewriter.create <Torch::ConstantBoolOp>(loc, false );
666
+ Value float32Type = rewriter.create <Torch::ConstantIntOp>(
667
+ loc, rewriter.getI64IntegerAttr (/* float32Type*/ 6 ));
668
+ Type f32aType = rewriter.getType <Torch::ValueTensorType>(
669
+ aTy.getSizes (), rewriter.getF32Type ());
670
+ a = rewriter.create <Torch::AtenToDtypeOp>(loc, f32aType, a,
671
+ float32Type,
672
+ /* non_blocking=*/ cstFalse,
673
+ /* copy=*/ cstFalse,
674
+ /* memory_format=*/ none);
675
+
676
+ Value cstOne = rewriter.create <Torch::ConstantFloatOp>(
677
+ loc, rewriter.getF64FloatAttr (1.0 ));
678
+ a = rewriter.create <Torch::AtenSubScalarOp>(loc, f32aType, a, aZp,
679
+ cstOne);
680
+ a = rewriter.create <Torch::AtenMulScalarOp>(loc, f32aType, a, aScale);
681
+
682
+ // Dequantizing the b
683
+ // Shapes of the inputs are as follows:
684
+ // b = (B, K, N) or (K, N)
685
+ // b_scale = (B, 1, N) or (1, N) or (N)
686
+ // b_zero_point = (B, 1, N) or (1, N) or (N)
687
+ //
688
+ // We compute the dequantized `b` as follows:
689
+ // b = b.to(dtype=torch.float32)
690
+ // b_dequant = (b - b_zero_point) * b_scale
691
+
692
+ // Converting the b tensor to float32 type.
693
+ Type f32bType = rewriter.getType <Torch::ValueTensorType>(
694
+ bShape, rewriter.getF32Type ());
695
+ b = rewriter.create <Torch::AtenToDtypeOp>(loc, f32bType, b,
696
+ float32Type,
697
+ /* non_blocking=*/ cstFalse,
698
+ /* copy=*/ cstFalse,
699
+ /* memory_format=*/ none);
700
+
701
+ b = rewriter.create <Torch::AtenSubTensorOp>(loc, f32bType, b, bZp,
702
+ cstOne);
703
+ b = rewriter.create <Torch::AtenMulTensorOp>(loc, f32bType, b, bScale);
704
+ } else {
705
+ llvm_unreachable (
706
+ " Unidentified case for quantization for `b` argument of"
707
+ " Onnx.QLinearMatMul op" );
708
+ }
709
+
710
+ if (!isPerColumnQuantization)
711
+ a = makePerTensor (a, aScale, aZp);
712
+
713
+ Type cDtype =
714
+ isPerColumnQuantization
715
+ ? cast<Type>(rewriter.getF32Type ())
716
+ : cast<Type>(rewriter.getIntegerType (32 , /* issigned=*/ true ));
624
717
625
718
auto cTy = rewriter.getType <Torch::ValueTensorType>(
626
- resultType.getOptionalSizes (),
627
- rewriter.getIntegerType (32 , /* issigned=*/ true ));
719
+ resultType.getOptionalSizes (), cDtype);
628
720
629
721
Value c;
630
722
if (cTy.getSizes ().size () == 2 ) {
@@ -633,23 +725,26 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
633
725
c = rewriter.create <Torch::AtenBmmOp>(binder.getLoc (), cTy, a, b);
634
726
}
635
727
636
- cTy = rewriter.getType <Torch::ValueTensorType>(
637
- resultType.getOptionalSizes (),
638
- rewriter.getType <Torch::QInt32Type>());
728
+ if (!isPerColumnQuantization) {
729
+ cTy = rewriter.getType <Torch::ValueTensorType>(
730
+ resultType.getOptionalSizes (),
731
+ rewriter.getType <Torch::QInt32Type>());
639
732
640
- Value mmScale = rewriter.create <Torch::AtenMulFloatOp>(
641
- binder.getLoc (), rewriter.getType <Torch::FloatType>(), aScale,
642
- bScale);
643
- Value mmZp = rewriter.create <Torch::ConstantIntOp>(
644
- binder.getLoc (), rewriter.getType <Torch::IntType>(),
645
- rewriter.getIntegerAttr (rewriter.getIntegerType (64 ), 0 ));
646
- c = rewriter.create <Torch::Aten_MakePerTensorQuantizedTensorOp>(
647
- binder.getLoc (), cTy, c, mmScale, mmZp);
648
- cTy = rewriter.getType <Torch::ValueTensorType>(
649
- resultType.getOptionalSizes (), rewriter.getF32Type ());
733
+ Value mmScale = rewriter.create <Torch::AtenMulFloatOp>(
734
+ binder.getLoc (), rewriter.getType <Torch::FloatType>(), aScale,
735
+ bScale);
736
+ Value mmZp = rewriter.create <Torch::ConstantIntOp>(
737
+ binder.getLoc (), rewriter.getType <Torch::IntType>(),
738
+ rewriter.getIntegerAttr (rewriter.getIntegerType (64 ), 0 ));
739
+ c = rewriter.create <Torch::Aten_MakePerTensorQuantizedTensorOp>(
740
+ binder.getLoc (), cTy, c, mmScale, mmZp);
741
+ cTy = rewriter.getType <Torch::ValueTensorType>(
742
+ resultType.getOptionalSizes (), rewriter.getF32Type ());
743
+
744
+ c = rewriter.create <Torch::AtenDequantizeSelfOp>(binder.getLoc (), cTy,
745
+ c);
746
+ }
650
747
651
- c = rewriter.create <Torch::AtenDequantizeSelfOp>(binder.getLoc (), cTy,
652
- c);
653
748
cTy = dyn_cast<Torch::ValueTensorType>(
654
749
getQTorchTypeFromTorchIntType (resultType));
655
750
Value dtyVal = rewriter.create <Torch::ConstantIntOp>(
0 commit comments