
Commit e4a2f86

[ONNX] Add per-column quantization support for Onnx.QLinearMatMul op (#4080)
This commit extends the OnnxToTorch lowering for the Onnx.QLinearMatMul op by adding support for per-column (per-channel) quantization of the input argument `b`. Since the `QuantizedMatmulOp` in the downstream ("Linalg") pipeline does not support per-column quantization for `b`, this lowering instead performs the matmul over the dequantized inputs `a` and `b` and then quantizes the output.

Fixes nod-ai/SHARK-ModelDev#916.

Signed-off-by: Vivek Khandelwal [email protected]
1 parent 1ea12de commit e4a2f86
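To make the new per-column path concrete, below is a small, self-contained C++ sketch of the arithmetic that path amounts to: dequantize `a` with its per-tensor scale/zero-point, dequantize `b` with one scale/zero-point per output column, multiply in float, then requantize the product with the output's per-tensor parameters. This is not the torch-mlir lowering itself; all tensor values, shapes, and quantization parameters below are illustrative assumptions.

// Illustrative sketch only: per-tensor dequant of a, per-column dequant of b,
// float matmul, then per-tensor requantization of the result.
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  // Made-up shapes and quantization parameters (assumptions, not from the commit).
  const int M = 1, K = 2, N = 3;
  std::vector<uint8_t> a = {12, 200};              // (M, K), uint8
  float aScale = 0.05f;                            // per-tensor scale for a
  int aZp = 10;                                    // per-tensor zero-point for a
  std::vector<int8_t> b = {3, -4, 7, 1, 9, -2};    // (K, N), int8, row-major
  std::vector<float> bScale = {0.1f, 0.2f, 0.05f}; // per-column: one scale per N
  std::vector<int8_t> bZp = {0, 1, -1};            // per-column zero-points
  float cScale = 0.5f;                             // per-tensor output scale
  int cZp = 128;                                   // per-tensor output zero-point

  std::vector<uint8_t> c(M * N);
  for (int m = 0; m < M; ++m) {
    for (int n = 0; n < N; ++n) {
      float acc = 0.0f;
      for (int k = 0; k < K; ++k) {
        // a_dequant = (a - a_zero_point) * a_scale
        float aDeq = (static_cast<float>(a[m * K + k]) - aZp) * aScale;
        // b_dequant = (b - b_zero_point[n]) * b_scale[n], indexed by column n.
        float bDeq = (static_cast<float>(b[k * N + n]) - bZp[n]) * bScale[n];
        acc += aDeq * bDeq;
      }
      // Requantize the float accumulator into the uint8 output tensor.
      long q = std::lround(acc / cScale) + cZp;
      c[m * N + n] = static_cast<uint8_t>(std::clamp(q, 0L, 255L));
    }
  }
  for (uint8_t v : c)
    std::cout << static_cast<int>(v) << " ";
  std::cout << "\n";
}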

File tree

2 files changed: +157 −26 lines

lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp

Lines changed: 121 additions & 26 deletions
@@ -556,6 +556,7 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
   patterns.onOp(
       "QLinearMatMul", 1,
       [](OpBinder binder, ConversionPatternRewriter &rewriter) {
+        Location loc = binder.getLoc();
         Torch::ValueTensorType resultType;
         llvm::SmallVector<Value> operands;
         if (binder.tensorOperands(operands, 8) ||
@@ -577,10 +578,10 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
             return false;
           return true;
         };
-        if (!check(aScale) || !check(aZp) || !check(bScale) || !check(bZp) ||
-            !check(cScale) || !check(cScale))
+        if (!check(aScale) || !check(aZp) || !check(cScale) || !check(cZp))
           return rewriter.notifyMatchFailure(
-              binder.op, "not supported for non per-tensor quantization");
+              binder.op, "input `a` and output not supported for non "
+                         "per-tensor quantization");

         Value emptyList = rewriter.create<Torch::PrimListConstructOp>(
             binder.getLoc(),
@@ -605,26 +606,117 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
         };

         aZp = extract(aZp);
-        bZp = extract(bZp);
         cZp = extract(cZp);
         aScale = extract(aScale);
-        bScale = extract(bScale);
         cScale = extract(cScale);

-        auto make = [&rewriter, &binder](Value v, Value scale,
-                                         Value zp) -> Value {
+        auto makePerTensor = [&rewriter, &binder](Value v, Value scale,
+                                                  Value zp) -> Value {
           auto ty = cast<Torch::ValueTensorType>(v.getType());
           auto newTy = getQTorchTypeFromTorchIntType(ty);
           return rewriter.create<Torch::Aten_MakePerTensorQuantizedTensorOp>(
               binder.getLoc(), newTy, v, scale, zp);
         };

-        a = make(a, aScale, aZp);
-        b = make(b, bScale, bZp);
+        // The onnx's QLinearMatMul op allows per-column (per-channel)
+        // quantization only for the "b" tensor.
+        bool isPerColumnQuantization = false;
+        auto bTy = dyn_cast<Torch::ValueTensorType>(b.getType());
+        auto bScaleTy = dyn_cast<Torch::ValueTensorType>(bScale.getType());
+        auto bZpTy = dyn_cast<Torch::ValueTensorType>(bZp.getType());
+        if (!bTy || !bScaleTy || !bZpTy || !bTy.hasSizes() ||
+            !bScaleTy.hasSizes() || !bZpTy.hasSizes())
+          return rewriter.notifyMatchFailure(
+              binder.op, "Expected b, b_scale, and b_zero_point "
+                         "arguments to have sizes");
+        ArrayRef<int64_t> bShape(bTy.getSizes());
+        SmallVector<int64_t> bScaleShape(bScaleTy.getSizes());
+        SmallVector<int64_t> bZpShape(bZpTy.getSizes());
+        if (bScaleShape.size() == 0 ||
+            llvm::all_of(bScaleShape, [](int64_t s) { return s == 1; })) {
+          bZp = extract(bZp);
+          bScale = extract(bScale);
+          b = makePerTensor(b, bScale, bZp);
+        } else if ((bScaleShape.size() == 1 ||
+                    bScaleShape.size() == bShape.size()) &&
+                   bScaleShape.back() != Torch::kUnknownSize &&
+                   bScaleShape.back() == bShape.back()) {
+          // Since the `QuantizedMatmulOp` in the downstream pipeline
+          // ("Linalg") does not support the per-column (per-channel)
+          // quantization for the arg `b`, hence for this particular case we
+          // perform the matmul over the dequantized inputs i.e., `a` and `b`
+          // instead of relying on the downstream pipeline to handle this. This
+          // code can be removed and made similar to the other paths in this
+          // lowering once the per-column (per-channel) quantization support is
+          // added in the downstream pipeline.
+          isPerColumnQuantization = true;
+
+          auto aTy = dyn_cast<Torch::ValueTensorType>(a.getType());
+          if (!aTy || !aTy.hasSizes())
+            return rewriter.notifyMatchFailure(
+                binder.op, "Expected input argument `a` to have sizes");
+
+          // Dequantizing the a
+          // a = a.to(dtype=torch.float32)
+          // a_dequant = (a - a_zero_point) * a_scale
+
+          // Converting the a tensor to float32 type.
+          Value none = rewriter.create<Torch::ConstantNoneOp>(loc);
+          Value cstFalse = rewriter.create<Torch::ConstantBoolOp>(loc, false);
+          Value float32Type = rewriter.create<Torch::ConstantIntOp>(
+              loc, rewriter.getI64IntegerAttr(/*float32Type*/ 6));
+          Type f32aType = rewriter.getType<Torch::ValueTensorType>(
+              aTy.getSizes(), rewriter.getF32Type());
+          a = rewriter.create<Torch::AtenToDtypeOp>(loc, f32aType, a,
+                                                    float32Type,
+                                                    /*non_blocking=*/cstFalse,
+                                                    /*copy=*/cstFalse,
+                                                    /*memory_format=*/none);
+
+          Value cstOne = rewriter.create<Torch::ConstantFloatOp>(
+              loc, rewriter.getF64FloatAttr(1.0));
+          a = rewriter.create<Torch::AtenSubScalarOp>(loc, f32aType, a, aZp,
+                                                      cstOne);
+          a = rewriter.create<Torch::AtenMulScalarOp>(loc, f32aType, a, aScale);
+
+          // Dequantizing the b
+          // Shapes of the inputs are as follows:
+          // b = (B, K, N) or (K, N)
+          // b_scale = (B, 1, N) or (1, N) or (N)
+          // b_zero_point = (B, 1, N) or (1, N) or (N)
+          //
+          // We compute the dequantized `b` as follows:
+          // b = b.to(dtype=torch.float32)
+          // b_dequant = (b - b_zero_point) * b_scale
+
+          // Converting the b tensor to float32 type.
+          Type f32bType = rewriter.getType<Torch::ValueTensorType>(
+              bShape, rewriter.getF32Type());
+          b = rewriter.create<Torch::AtenToDtypeOp>(loc, f32bType, b,
+                                                    float32Type,
+                                                    /*non_blocking=*/cstFalse,
+                                                    /*copy=*/cstFalse,
+                                                    /*memory_format=*/none);
+
+          b = rewriter.create<Torch::AtenSubTensorOp>(loc, f32bType, b, bZp,
+                                                      cstOne);
+          b = rewriter.create<Torch::AtenMulTensorOp>(loc, f32bType, b, bScale);
+        } else {
+          llvm_unreachable(
+              "Unidentified case for quantization for `b` argument of"
+              "Onnx.QLinearMatMul op");
+        }
+
+        if (!isPerColumnQuantization)
+          a = makePerTensor(a, aScale, aZp);
+
+        Type cDtype =
+            isPerColumnQuantization
+                ? cast<Type>(rewriter.getF32Type())
+                : cast<Type>(rewriter.getIntegerType(32, /*issigned=*/true));

         auto cTy = rewriter.getType<Torch::ValueTensorType>(
-            resultType.getOptionalSizes(),
-            rewriter.getIntegerType(32, /*issigned=*/true));
+            resultType.getOptionalSizes(), cDtype);

         Value c;
         if (cTy.getSizes().size() == 2) {
@@ -633,23 +725,26 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
           c = rewriter.create<Torch::AtenBmmOp>(binder.getLoc(), cTy, a, b);
         }

-        cTy = rewriter.getType<Torch::ValueTensorType>(
-            resultType.getOptionalSizes(),
-            rewriter.getType<Torch::QInt32Type>());
+        if (!isPerColumnQuantization) {
+          cTy = rewriter.getType<Torch::ValueTensorType>(
+              resultType.getOptionalSizes(),
+              rewriter.getType<Torch::QInt32Type>());

-        Value mmScale = rewriter.create<Torch::AtenMulFloatOp>(
-            binder.getLoc(), rewriter.getType<Torch::FloatType>(), aScale,
-            bScale);
-        Value mmZp = rewriter.create<Torch::ConstantIntOp>(
-            binder.getLoc(), rewriter.getType<Torch::IntType>(),
-            rewriter.getIntegerAttr(rewriter.getIntegerType(64), 0));
-        c = rewriter.create<Torch::Aten_MakePerTensorQuantizedTensorOp>(
-            binder.getLoc(), cTy, c, mmScale, mmZp);
-        cTy = rewriter.getType<Torch::ValueTensorType>(
-            resultType.getOptionalSizes(), rewriter.getF32Type());
+          Value mmScale = rewriter.create<Torch::AtenMulFloatOp>(
+              binder.getLoc(), rewriter.getType<Torch::FloatType>(), aScale,
+              bScale);
+          Value mmZp = rewriter.create<Torch::ConstantIntOp>(
+              binder.getLoc(), rewriter.getType<Torch::IntType>(),
+              rewriter.getIntegerAttr(rewriter.getIntegerType(64), 0));
+          c = rewriter.create<Torch::Aten_MakePerTensorQuantizedTensorOp>(
+              binder.getLoc(), cTy, c, mmScale, mmZp);
+          cTy = rewriter.getType<Torch::ValueTensorType>(
+              resultType.getOptionalSizes(), rewriter.getF32Type());
+
+          c = rewriter.create<Torch::AtenDequantizeSelfOp>(binder.getLoc(), cTy,
+                                                           c);
+        }

-        c = rewriter.create<Torch::AtenDequantizeSelfOp>(binder.getLoc(), cTy,
-                                                         c);
         cTy = dyn_cast<Torch::ValueTensorType>(
             getQTorchTypeFromTorchIntType(resultType));
         Value dtyVal = rewriter.create<Torch::ConstantIntOp>(

test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir

Lines changed: 36 additions & 0 deletions
@@ -251,6 +251,42 @@ func.func @test_qlinearmatmul_3D(%arg0: !torch.vtensor<[2,2,4],ui8>, %arg1: !tor

 // -----

+// CHECK-LABEL: func.func @test_qlinearmatmul_per_channel_quantization(
+// CHECK-SAME: %[[A:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: !torch.vtensor<[1,4096],ui8>,
+// CHECK-SAME: %[[A_SCALE:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: !torch.vtensor<[],f32>,
+// CHECK-SAME: %[[A_ZERO_POINT:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: !torch.vtensor<[],ui8>,
+// CHECK-SAME: %[[B:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: !torch.vtensor<[4096,1024],si8>,
+// CHECK-SAME: %[[B_SCALE:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: !torch.vtensor<[1024],f32>,
+// CHECK-SAME: %[[B_ZERO_POINT:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: !torch.vtensor<[1024],si8>,
+// CHECK-SAME: %[[C_SCALE:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: !torch.vtensor<[],f32>,
+// CHECK-SAME: %[[C_ZERO_POINT:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: !torch.vtensor<[],ui8>) -> !torch.vtensor<[1,1024],ui8>
+func.func @test_qlinearmatmul_per_channel_quantization(%arg0: !torch.vtensor<[1,4096],ui8>, %arg1: !torch.vtensor<[],f32>, %arg2: !torch.vtensor<[],ui8>, %arg3: !torch.vtensor<[4096,1024],si8>, %arg4: !torch.vtensor<[1024],f32>, %arg5: !torch.vtensor<[1024],si8>, %arg6: !torch.vtensor<[],f32>, %arg7: !torch.vtensor<[],ui8>) -> !torch.vtensor<[1,1024],ui8> attributes {torch.onnx_meta.ir_version = 3 : si64, torch.onnx_meta.opset_version = 21 : si64, torch.onnx_meta.opset_versions = {ai.onnx.contrib = 1000 : si64, ai.onnx.ml = 3 : si64, ai.onnx.preview.training = 1 : si64, ai.onnx.training = 1 : si64, com.microsoft = 1 : si64, com.microsoft.experimental = 1 : si64, com.microsoft.nchwc = 1 : si64, com.ms.internal.nhwc = 19 : si64, org.pytorch.aten = 1 : si64}, torch.onnx_meta.producer_name = "onnx.quantize", torch.onnx_meta.producer_version = "0.1.0"} {
+  %0 = torch.operator "onnx.QLinearMatMul"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7) : (!torch.vtensor<[1,4096],ui8>, !torch.vtensor<[],f32>, !torch.vtensor<[],ui8>, !torch.vtensor<[4096,1024],si8>, !torch.vtensor<[1024],f32>, !torch.vtensor<[1024],si8>, !torch.vtensor<[],f32>, !torch.vtensor<[],ui8>) -> !torch.vtensor<[1,1024],ui8>
+  // CHECK: %[[EMPTY:.+]] = torch.prim.ListConstruct : () -> !torch.list<int>
+  // CHECK: %[[AZP:.+]] = torch.aten.item %[[A_ZERO_POINT]] : !torch.vtensor<[],ui8> -> !torch.int
+  // CHECK: %[[CZP:.+]] = torch.aten.item %[[C_ZERO_POINT]] : !torch.vtensor<[],ui8> -> !torch.int
+  // CHECK: %[[ASCALE:.+]] = torch.aten.item %[[A_SCALE]] : !torch.vtensor<[],f32> -> !torch.float
+  // CHECK: %[[CSCALE:.+]] = torch.aten.item %[[C_SCALE]] : !torch.vtensor<[],f32> -> !torch.float
+  // CHECK: %[[NONE:.*]] = torch.constant.none
+  // CHECK: %[[FALSE:.*]] = torch.constant.bool false
+  // CHECK: %[[F32DTYPE:.*]] = torch.constant.int 6
+  // CHECK: %[[A_F32:.*]] = torch.aten.to.dtype %[[A]], %[[F32DTYPE]], %[[FALSE]], %[[FALSE]], %[[NONE]] : !torch.vtensor<[1,4096],ui8>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[1,4096],f32>
+  // CHECK: %[[ALPHA:.*]] = torch.constant.float 1.000000e+00
+  // CHECK: %[[A_MINUS_ZP:.*]] = torch.aten.sub.Scalar %[[A_F32]], %[[AZP]], %[[ALPHA]] : !torch.vtensor<[1,4096],f32>, !torch.int, !torch.float -> !torch.vtensor<[1,4096],f32>
+  // CHECK: %[[LHS:.*]] = torch.aten.mul.Scalar %[[A_MINUS_ZP]], %[[ASCALE]] : !torch.vtensor<[1,4096],f32>, !torch.float -> !torch.vtensor<[1,4096],f32>
+  // CHECK: %[[B_F32:.*]] = torch.aten.to.dtype %[[B]], %[[F32DTYPE]], %[[FALSE]], %[[FALSE]], %[[NONE]] : !torch.vtensor<[4096,1024],si8>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[4096,1024],f32>
+  // CHECK: %[[B_MINUS_ZP:.*]] = torch.aten.sub.Tensor %[[B_F32]], %[[B_ZERO_POINT]], %[[ALPHA]] : !torch.vtensor<[4096,1024],f32>, !torch.vtensor<[1024],si8>, !torch.float -> !torch.vtensor<[4096,1024],f32>
+  // CHECK: %[[RHS:.*]] = torch.aten.mul.Tensor %[[B_MINUS_ZP]], %[[B_SCALE]] : !torch.vtensor<[4096,1024],f32>, !torch.vtensor<[1024],f32> -> !torch.vtensor<[4096,1024],f32>
+  // CHECK: %[[MM:.*]] = torch.aten.mm %[[LHS]], %[[RHS]] : !torch.vtensor<[1,4096],f32>, !torch.vtensor<[4096,1024],f32> -> !torch.vtensor<[1,1024],f32>
+  // CHECK: %[[DTYPE:.*]] = torch.constant.int 13
+  // CHECK: %[[QO:.*]] = torch.aten.quantize_per_tensor %[[MM]], %[[CSCALE]], %[[CZP]], %[[DTYPE]] : !torch.vtensor<[1,1024],f32>, !torch.float, !torch.int, !torch.int -> !torch.vtensor<[1,1024],!torch.quint8>
+  // CHECK: %[[OUT:.*]] = torch.aten.int_repr %[[QO]] : !torch.vtensor<[1,1024],!torch.quint8> -> !torch.vtensor<[1,1024],ui8>
+  // CHECK: return %[[OUT]] : !torch.vtensor<[1,1024],ui8>
+  return %0 : !torch.vtensor<[1,1024],ui8>
+}
+
+// -----
+
 // CHECK-LABEL: func.func @test_reciprocal
 func.func @test_reciprocal(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
   // CHECK: torch.aten.reciprocal %arg0 : !torch.vtensor<[3,4,5],f32> -> !torch.vtensor<[3,4,5],f32>
