Commit 388924c

[ONNX] Add support for non-scalar zero points for Onnx.MatmulInteger op (#4036)
This commit adds support for non-scalar zero points in the ONNX->Torch lowering of the MatMulInteger op.

Signed-off-by: Vivek Khandelwal <[email protected]>
1 parent 11d0853 commit 388924c
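
For context: onnx.MatMulInteger computes Y = (A - a_zero_point) * (B - b_zero_point) with int32 accumulation, where each zero point may be a scalar (per-tensor quantization) or, for 2-D inputs, a 1-D tensor with one entry per row of A or per column of B. The sketch below is an illustrative reference implementation of those semantics, not code from this repository; the function name and the flat row-major layout are assumptions made for the example.

// Reference semantics of onnx.MatMulInteger for 2-D operands
// (illustrative only): A is [M, K] uint8, B is [K, N] int8, and each
// zero-point vector holds either one entry (per-tensor) or one entry
// per row of A / per column of B (per-channel).
#include <cstdint>
#include <vector>

std::vector<int32_t> matMulInteger(const std::vector<uint8_t> &a,
                                   const std::vector<int8_t> &b,
                                   const std::vector<int32_t> &aZp,
                                   const std::vector<int32_t> &bZp,
                                   int m, int k, int n) {
  std::vector<int32_t> y(m * n, 0);
  for (int i = 0; i < m; ++i) {
    int32_t az = aZp[aZp.size() == 1 ? 0 : i]; // scalar or per-row
    for (int j = 0; j < n; ++j) {
      int32_t bz = bZp[bZp.size() == 1 ? 0 : j]; // scalar or per-column
      int32_t acc = 0;
      for (int p = 0; p < k; ++p)
        acc += (int32_t(a[i * k + p]) - az) * (int32_t(b[p * n + j]) - bz);
      y[i * n + j] = acc;
    }
  }
  return y;
}

Before this change, the lowering bailed out (returned failure()) whenever either zero point was a tensor with a non-unit dimension, i.e. it only handled the per-tensor case.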

File tree

2 files changed: +125, -22 lines

lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp

Lines changed: 72 additions & 16 deletions
@@ -573,36 +573,92 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
               rewriter.getIntegerAttr(rewriter.getIntegerType(64), 0));
         }
 
+        bool isChannelQuantizationForLhs = false;
         if (auto zpTy = dyn_cast<Torch::ValueTensorType>(lhsZp.getType())) {
-          for (auto dim : zpTy.getSizes())
-            if (dim != 1)
+          auto lhsZpSize = zpTy.getSizes();
+          if (lhsZpSize.size() == 0 ||
+              llvm::all_of(lhsZpSize, [](int64_t d) { return d == 1; })) {
+            lhsZp = rewriter.create<Torch::AtenItemOp>(
+                binder.getLoc(), rewriter.getType<Torch::IntType>(), lhsZp);
+          } else if (lhsZpSize.size() == 1) {
+            auto lhsSize = lhsTy.getSizes();
+            if (lhsSize.size() != 2 || lhsSize[0] != lhsZpSize[0])
               return failure();
-          lhsZp = rewriter.create<Torch::AtenItemOp>(
-              binder.getLoc(), rewriter.getType<Torch::IntType>(), lhsZp);
+            isChannelQuantizationForLhs = true;
+          } else {
+            return failure();
+          }
         }
 
+        bool isChannelQuantizationForRhs = false;
         if (auto zpTy = dyn_cast<Torch::ValueTensorType>(rhsZp.getType())) {
-          for (auto dim : zpTy.getSizes())
-            if (dim != 1)
+          auto rhsZpSize = zpTy.getSizes();
+          if (rhsZpSize.size() == 0 ||
+              llvm::all_of(rhsZpSize, [](int64_t d) { return d == 1; })) {
+            rhsZp = rewriter.create<Torch::AtenItemOp>(
+                binder.getLoc(), rewriter.getType<Torch::IntType>(), rhsZp);
+          } else if (rhsZpSize.size() == 1) {
+            auto rhsSize = rhsTy.getSizes();
+            if (rhsSize.size() != 2 || rhsSize[1] != rhsZpSize[0])
               return failure();
-          rhsZp = rewriter.create<Torch::AtenItemOp>(
-              binder.getLoc(), rewriter.getType<Torch::IntType>(), rhsZp);
+            isChannelQuantizationForRhs = true;
+          } else {
+            return failure();
+          }
         }
 
-        Value scale = rewriter.create<Torch::ConstantFloatOp>(
-            binder.getLoc(), rewriter.getType<Torch::FloatType>(),
-            rewriter.getF64FloatAttr(1.0));
-
         auto lhsQTy = getQTorchTypeFromTorchIntType(lhsTy);
         auto rhsQTy = getQTorchTypeFromTorchIntType(rhsTy);
 
         if (!lhsQTy || !rhsQTy)
           return rewriter.notifyMatchFailure(binder.op, "failed to get qtype");
 
-        lhs = rewriter.create<Torch::Aten_MakePerTensorQuantizedTensorOp>(
-            binder.getLoc(), lhsQTy, lhs, scale, lhsZp);
-        rhs = rewriter.create<Torch::Aten_MakePerTensorQuantizedTensorOp>(
-            binder.getLoc(), rhsQTy, rhs, scale, rhsZp);
+        Value f32Ty = rewriter.create<Torch::ConstantIntOp>(
+            binder.getLoc(), rewriter.getI64IntegerAttr(
+                                 (int64_t)torch_upstream::ScalarType::Float));
+        Value none = rewriter.create<Torch::ConstantNoneOp>(binder.getLoc());
+
+        if (isChannelQuantizationForLhs) {
+          Value axis = rewriter.create<Torch::ConstantIntOp>(
+              binder.getLoc(), rewriter.getType<Torch::IntType>(),
+              rewriter.getI64IntegerAttr(0));
+          Torch::ValueTensorType lhsZpTy =
+              dyn_cast<Torch::ValueTensorType>(lhsZp.getType());
+          Type scaleTy = lhsZpTy.getWithSizesAndDtype(lhsZpTy.getSizes(),
+                                                      rewriter.getF32Type());
+          Value scale = rewriter.create<Torch::AtenOnesLikeOp>(
+              binder.getLoc(), scaleTy, /*self=*/lhsZp, f32Ty, /*layout=*/none,
+              /*device=*/none, /*pin_memory=*/none, /*memory_format=*/none);
+          lhs = rewriter.create<Torch::Aten_MakePerChannelQuantizedTensorOp>(
+              binder.getLoc(), lhsQTy, lhs, scale, lhsZp, axis);
+        } else {
+          Value scale = rewriter.create<Torch::ConstantFloatOp>(
+              binder.getLoc(), rewriter.getType<Torch::FloatType>(),
+              rewriter.getF64FloatAttr(1.0));
+          lhs = rewriter.create<Torch::Aten_MakePerTensorQuantizedTensorOp>(
+              binder.getLoc(), lhsQTy, lhs, scale, lhsZp);
+        }
+
+        if (isChannelQuantizationForRhs) {
+          Value axis = rewriter.create<Torch::ConstantIntOp>(
+              binder.getLoc(), rewriter.getType<Torch::IntType>(),
+              rewriter.getI64IntegerAttr(1));
+          Torch::ValueTensorType rhsZpTy =
+              dyn_cast<Torch::ValueTensorType>(rhsZp.getType());
+          Type scaleTy = rhsZpTy.getWithSizesAndDtype(rhsZpTy.getSizes(),
+                                                      rewriter.getF32Type());
+          Value scale = rewriter.create<Torch::AtenOnesLikeOp>(
+              binder.getLoc(), scaleTy, /*self=*/rhsZp, f32Ty, /*layout=*/none,
+              /*device=*/none, /*pin_memory=*/none, /*memory_format=*/none);
+          rhs = rewriter.create<Torch::Aten_MakePerChannelQuantizedTensorOp>(
+              binder.getLoc(), rhsQTy, rhs, scale, rhsZp, axis);
+        } else {
+          Value scale = rewriter.create<Torch::ConstantFloatOp>(
+              binder.getLoc(), rewriter.getType<Torch::FloatType>(),
+              rewriter.getF64FloatAttr(1.0));
+          rhs = rewriter.create<Torch::Aten_MakePerTensorQuantizedTensorOp>(
+              binder.getLoc(), rhsQTy, rhs, scale, rhsZp);
+        }
 
         rewriter.replaceOpWithNewOp<Torch::AtenMatmulOp>(binder.op, resultType,
                                                          lhs, rhs);
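
In summary, the rewritten lowering classifies each zero point by shape: a rank-0 or all-unit-dim tensor is treated as a scalar (extracted with aten.item and fed to _make_per_tensor_quantized_tensor, as before), while a 1-D zero point whose length matches dim 0 of a 2-D LHS or dim 1 of a 2-D RHS selects the new per-channel path, which builds an all-ones f32 scale tensor with aten.ones_like and wraps the operand in _make_per_channel_quantized_tensor along axis 0 (LHS) or axis 1 (RHS). Any other shape still fails to match. Below is a standalone sketch of that classification; the names are hypothetical and a plain shape vector stands in for the MLIR tensor types.

#include <cstdint>
#include <vector>

enum class ZpKind { PerTensor, PerChannel, Unsupported };

// channelDim is 0 for the LHS (one zero point per row) and 1 for the
// RHS (one zero point per column), matching the axes used above.
ZpKind classifyZeroPoint(const std::vector<int64_t> &zpShape,
                         const std::vector<int64_t> &operandShape,
                         int channelDim) {
  // Rank-0 or all-unit dims: a single value, i.e. per-tensor quantization.
  bool allOnes = true;
  for (int64_t d : zpShape)
    allOnes = allOnes && (d == 1);
  if (zpShape.empty() || allOnes)
    return ZpKind::PerTensor;
  // A 1-D zero point matching the channel dim of a 2-D operand:
  // per-channel quantization.
  if (zpShape.size() == 1 && operandShape.size() == 2 &&
      operandShape[channelDim] == zpShape[0])
    return ZpKind::PerChannel;
  // Everything else (e.g. batched operands with vector zero points)
  // still fails to match.
  return ZpKind::Unsupported;
}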

test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir

Lines changed: 53 additions & 6 deletions
@@ -569,9 +569,10 @@ func.func @test_matmulinteger(%arg0: !torch.vtensor<[4,3],ui8>, %arg1: !torch.vt
   %0 = torch.operator "onnx.MatMulInteger"(%arg0, %arg1, %arg2, %arg3) : (!torch.vtensor<[4,3],ui8>, !torch.vtensor<[3,2],ui8>, !torch.vtensor<[1],ui8>, !torch.vtensor<[1],ui8>) -> !torch.vtensor<[4,2],si32>
   // CHECK: %[[LITEM:.+]] = torch.aten.item %arg2
   // CHECK: %[[RITEM:.+]] = torch.aten.item %arg3
-  // CHECK: %[[SCALE:.+]] = torch.constant.float 1.000000e+00
-  // CHECK: %[[LMAKE:.+]] = torch.aten._make_per_tensor_quantized_tensor %arg0, %[[SCALE]], %[[LITEM]] : !torch.vtensor<[4,3],ui8>, !torch.float, !torch.int -> !torch.vtensor<[4,3],!torch.quint8>
-  // CHECK: %[[RMAKE:.+]] = torch.aten._make_per_tensor_quantized_tensor %arg1, %[[SCALE]], %[[RITEM]] : !torch.vtensor<[3,2],ui8>, !torch.float, !torch.int -> !torch.vtensor<[3,2],!torch.quint8>
+  // CHECK: %[[L_SCALE:.+]] = torch.constant.float 1.000000e+00
+  // CHECK: %[[LMAKE:.+]] = torch.aten._make_per_tensor_quantized_tensor %arg0, %[[L_SCALE]], %[[LITEM]] : !torch.vtensor<[4,3],ui8>, !torch.float, !torch.int -> !torch.vtensor<[4,3],!torch.quint8>
+  // CHECK: %[[R_SCALE:.+]] = torch.constant.float 1.000000e+00
+  // CHECK: %[[RMAKE:.+]] = torch.aten._make_per_tensor_quantized_tensor %arg1, %[[R_SCALE]], %[[RITEM]] : !torch.vtensor<[3,2],ui8>, !torch.float, !torch.int -> !torch.vtensor<[3,2],!torch.quint8>
   // CHECK: %[[MM:.+]] = torch.aten.matmul %[[LMAKE]], %[[RMAKE]]
   // CHECK: return %[[MM]]
   return %0 : !torch.vtensor<[4,2],si32>
@@ -584,13 +585,59 @@ func.func @test_matmulinteger_batched(%arg0: !torch.vtensor<[7,4,3],ui8>, %arg1:
   %0 = torch.operator "onnx.MatMulInteger"(%arg0, %arg1, %arg2, %arg3) : (!torch.vtensor<[7,4,3],ui8>, !torch.vtensor<[3,2],ui8>, !torch.vtensor<[1],ui8>, !torch.vtensor<[1],ui8>) -> !torch.vtensor<[7,4,2],si32>
   // CHECK: %[[LITEM:.+]] = torch.aten.item %arg2
   // CHECK: %[[RITEM:.+]] = torch.aten.item %arg3
-  // CHECK: %[[SCALE:.+]] = torch.constant.float 1.000000e+00
-  // CHECK: %[[LMAKE:.+]] = torch.aten._make_per_tensor_quantized_tensor %arg0, %[[SCALE]], %[[LITEM]] : !torch.vtensor<[7,4,3],ui8>, !torch.float, !torch.int -> !torch.vtensor<[7,4,3],!torch.quint8>
-  // CHECK: %[[RMAKE:.+]] = torch.aten._make_per_tensor_quantized_tensor %arg1, %[[SCALE]], %[[RITEM]] : !torch.vtensor<[3,2],ui8>, !torch.float, !torch.int -> !torch.vtensor<[3,2],!torch.quint8>
+  // CHECK: %[[L_SCALE:.+]] = torch.constant.float 1.000000e+00
+  // CHECK: %[[LMAKE:.+]] = torch.aten._make_per_tensor_quantized_tensor %arg0, %[[L_SCALE]], %[[LITEM]] : !torch.vtensor<[7,4,3],ui8>, !torch.float, !torch.int -> !torch.vtensor<[7,4,3],!torch.quint8>
+  // CHECK: %[[R_SCALE:.+]] = torch.constant.float 1.000000e+00
+  // CHECK: %[[RMAKE:.+]] = torch.aten._make_per_tensor_quantized_tensor %arg1, %[[R_SCALE]], %[[RITEM]] : !torch.vtensor<[3,2],ui8>, !torch.float, !torch.int -> !torch.vtensor<[3,2],!torch.quint8>
   // CHECK: %[[MM:.+]] = torch.aten.matmul %[[LMAKE]], %[[RMAKE]]
   // CHECK: return %[[MM]]
   return %0 : !torch.vtensor<[7,4,2],si32>
 }
+
+// -----
+
+// CHECK-LABEL: func.func @test_matmulinteger_non_scalar_lhsZp(
+// CHECK-SAME:    %[[VAL_0:.*]]: !torch.vtensor<[16,2],ui8>,
+// CHECK-SAME:    %[[VAL_1:.*]]: !torch.vtensor<[2,768],si8>,
+// CHECK-SAME:    %[[VAL_2:.*]]: !torch.vtensor<[16],ui8>,
+// CHECK-SAME:    %[[VAL_3:.*]]: !torch.vtensor<[],si8>) -> !torch.vtensor<[16,768],si32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 21 : si64, torch.onnx_meta.producer_name = "pytorch", torch.onnx_meta.producer_version = "0.1.0"} {
+func.func @test_matmulinteger_non_scalar_lhsZp(%arg0: !torch.vtensor<[16, 2],ui8>, %arg1: !torch.vtensor<[2,768],si8>, %arg2: !torch.vtensor<[16],ui8>, %arg3: !torch.vtensor<[],si8>) -> !torch.vtensor<[16,768],si32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 21 : si64, torch.onnx_meta.producer_name = "pytorch", torch.onnx_meta.producer_version = "0.1.0"} {
+  // CHECK: %[[VAL_4:.*]] = torch.aten.item %[[VAL_3]] : !torch.vtensor<[],si8> -> !torch.int
+  // CHECK: %[[VAL_5:.*]] = torch.constant.int 6
+  // CHECK: %[[VAL_6:.*]] = torch.constant.none
+  // CHECK: %[[VAL_7:.*]] = torch.constant.int 0
+  // CHECK: %[[VAL_8:.*]] = torch.aten.ones_like %[[VAL_2]], %[[VAL_5]], %[[VAL_6]], %[[VAL_6]], %[[VAL_6]], %[[VAL_6]] : !torch.vtensor<[16],ui8>, !torch.int, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[16],f32>
+  // CHECK: %[[VAL_9:.*]] = torch.aten._make_per_channel_quantized_tensor %[[VAL_0]], %[[VAL_8]], %[[VAL_2]], %[[VAL_7]] : !torch.vtensor<[16,2],ui8>, !torch.vtensor<[16],f32>, !torch.vtensor<[16],ui8>, !torch.int -> !torch.vtensor<[16,2],!torch.quint8>
+  // CHECK: %[[VAL_10:.*]] = torch.constant.float 1.000000e+00
+  // CHECK: %[[VAL_11:.*]] = torch.aten._make_per_tensor_quantized_tensor %[[VAL_1]], %[[VAL_10]], %[[VAL_4]] : !torch.vtensor<[2,768],si8>, !torch.float, !torch.int -> !torch.vtensor<[2,768],!torch.qint8>
+  // CHECK: %[[VAL_12:.*]] = torch.aten.matmul %[[VAL_9]], %[[VAL_11]] : !torch.vtensor<[16,2],!torch.quint8>, !torch.vtensor<[2,768],!torch.qint8> -> !torch.vtensor<[16,768],si32>
+  // CHECK: return %[[VAL_12]] : !torch.vtensor<[16,768],si32>
+  %0 = torch.operator "onnx.MatMulInteger"(%arg0, %arg1, %arg2, %arg3) : (!torch.vtensor<[16,2],ui8>, !torch.vtensor<[2,768],si8>, !torch.vtensor<[16],ui8>, !torch.vtensor<[],si8>) -> !torch.vtensor<[16,768],si32>
+  return %0 : !torch.vtensor<[16,768],si32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @test_matmulinteger_non_scalar_rhsZp(
+// CHECK-SAME:    %[[VAL_0:.*]]: !torch.vtensor<[?,?],ui8>,
+// CHECK-SAME:    %[[VAL_1:.*]]: !torch.vtensor<[2,768],si8>,
+// CHECK-SAME:    %[[VAL_2:.*]]: !torch.vtensor<[],ui8>,
+// CHECK-SAME:    %[[VAL_3:.*]]: !torch.vtensor<[768],si8>) -> !torch.vtensor<[?,768],si32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 21 : si64, torch.onnx_meta.producer_name = "pytorch", torch.onnx_meta.producer_version = "0.1.0"} {
+func.func @test_matmulinteger_non_scalar_rhsZp(%arg0: !torch.vtensor<[?,?],ui8>, %arg1: !torch.vtensor<[2,768],si8>, %arg2: !torch.vtensor<[],ui8>, %arg3: !torch.vtensor<[768],si8>) -> !torch.vtensor<[?,768],si32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 21 : si64, torch.onnx_meta.producer_name = "pytorch", torch.onnx_meta.producer_version = "0.1.0"} {
+  // CHECK: %[[VAL_4:.*]] = torch.aten.item %[[VAL_2]] : !torch.vtensor<[],ui8> -> !torch.int
+  // CHECK: %[[VAL_5:.*]] = torch.constant.int 6
+  // CHECK: %[[VAL_6:.*]] = torch.constant.none
+  // CHECK: %[[VAL_7:.*]] = torch.constant.float 1.000000e+00
+  // CHECK: %[[VAL_8:.*]] = torch.aten._make_per_tensor_quantized_tensor %[[VAL_0]], %[[VAL_7]], %[[VAL_4]] : !torch.vtensor<[?,?],ui8>, !torch.float, !torch.int -> !torch.vtensor<[?,?],!torch.quint8>
+  // CHECK: %[[VAL_9:.*]] = torch.constant.int 1
+  // CHECK: %[[VAL_10:.*]] = torch.aten.ones_like %[[VAL_3]], %[[VAL_5]], %[[VAL_6]], %[[VAL_6]], %[[VAL_6]], %[[VAL_6]] : !torch.vtensor<[768],si8>, !torch.int, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[768],f32>
+  // CHECK: %[[VAL_11:.*]] = torch.aten._make_per_channel_quantized_tensor %[[VAL_1]], %[[VAL_10]], %[[VAL_3]], %[[VAL_9]] : !torch.vtensor<[2,768],si8>, !torch.vtensor<[768],f32>, !torch.vtensor<[768],si8>, !torch.int -> !torch.vtensor<[2,768],!torch.qint8>
+  // CHECK: %[[VAL_12:.*]] = torch.aten.matmul %[[VAL_8]], %[[VAL_11]] : !torch.vtensor<[?,?],!torch.quint8>, !torch.vtensor<[2,768],!torch.qint8> -> !torch.vtensor<[?,768],si32>
+  // CHECK: return %[[VAL_12]] : !torch.vtensor<[?,768],si32>
+  %0 = torch.operator "onnx.MatMulInteger"(%arg0, %arg1, %arg2, %arg3) : (!torch.vtensor<[?,?],ui8>, !torch.vtensor<[2,768],si8>, !torch.vtensor<[],ui8>, !torch.vtensor<[768],si8>) -> !torch.vtensor<[?,768],si32>
+  return %0 : !torch.vtensor<[?,768],si32>
+}
+
 // -----
 
 // CHECK-LABEL: func.func @test_mul
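
The two new tests exercise both per-channel paths: a [16] zero point against a [16,2] LHS (axis 0) and a [768] zero point against a [2,768] RHS (axis 1). The lowering is sound because all scales are fixed to 1.0, so a quantized element q with zero point z represents exactly the integer q - z, and the standard expansion below (ordinary algebra, not taken from this commit) shows that a per-row correction on A and a per-column correction on B compose cleanly with matmul:

Y_{ij} = \sum_{p=1}^{K} \bigl(A_{ip} - a^{zp}_i\bigr)\bigl(B_{pj} - b^{zp}_j\bigr)
       = \sum_{p} A_{ip} B_{pj} - a^{zp}_i \sum_{p} B_{pj} - b^{zp}_j \sum_{p} A_{ip} + K\, a^{zp}_i\, b^{zp}_j

Each correction term depends only on the row index i or the column index j, which is why a 1-D zero point is only accepted along dim 0 of the LHS and dim 1 of the RHS.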
