Commit cfb213e

[ONNX] Add support for Onnx.QLinearAdd op (#4113)
This commit adds the Onnx->Torch lowering for the [Onnx.QLinearAdd](https://github.com/microsoft/onnxruntime/blob/main/docs/ContribOperators.md#commicrosoftqlinearadd) op.

Signed-off-by: Vivek Khandelwal <[email protected]>
1 parent 2d654bb commit cfb213e
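
For context, QLinearAdd computes a quantized sum: C = quantize(dequantize(A) + dequantize(B), c_scale, c_zero_point). The scalar sketch below illustrates the per-tensor semantics that the lowering targets, assuming round-to-nearest and saturation to the ui8 range; the helper name `qlinear_add_scalar` and its signature are illustrative only, not code from this commit or from ONNX Runtime.

```cpp
#include <algorithm>
#include <cmath>
#include <cstdint>

// Hypothetical scalar reference: c = quantize(dequantize(a) + dequantize(b)).
std::uint8_t qlinear_add_scalar(std::uint8_t a, float aScale, int aZp,
                                std::uint8_t b, float bScale, int bZp,
                                float cScale, int cZp) {
  float va = (a - aZp) * aScale;             // dequantize a
  float vb = (b - bZp) * bScale;             // dequantize b
  float sum = va + vb;                       // add in f32 (alpha == 1.0)
  long q = std::lround(sum / cScale) + cZp;  // requantize with output params
  return static_cast<std::uint8_t>(std::clamp(q, 0L, 255L));  // saturate to ui8
}
```

The lowering below realizes the same three steps on tensors: torch.aten._make_per_tensor_quantized_tensor plus torch.aten.dequantize.self for each input, torch.aten.add.Tensor in f32, then torch.aten.quantize_per_tensor followed by torch.aten.int_repr for the output.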

File tree

2 files changed: +156 −0 lines changed

lib/Conversion/TorchOnnxToTorch/ComMicrosoftDomain.cpp

Lines changed: 123 additions & 0 deletions
@@ -455,4 +455,127 @@ void mlir::torch::onnx_c::populateComMicrosoftDomain(
         rewriter.replaceOp(binder.op, {attention, presentKey, presentValue});
         return success();
       });
+  patterns.onOp(
+      "QLinearAdd", 1,
+      [](OpBinder binder, ConversionPatternRewriter &rewriter) {
+        Location loc = binder.getLoc();
+        Torch::ValueTensorType resultType;
+        llvm::SmallVector<Value> operands;
+        if (binder.tensorOperandsList(operands) ||
+            binder.tensorResultType(resultType))
+          return failure();
+        Value a = operands[0];
+        Value aScale = operands[1];
+        Value aZp = operands[2];
+        Value b = operands[3];
+        Value bScale = operands[4];
+        Value bZp = operands[5];
+        Value cScale = operands[6];
+
+        auto check = [](Value v) {
+          auto vTy = cast<Torch::ValueTensorType>(v.getType());
+          for (auto dim : vTy.getSizes())
+            if (dim != 1)
+              return false;
+          return true;
+        };
+        if (!check(aScale) || !check(aZp) || !check(bScale) || !check(bZp) ||
+            !check(cScale))
+          return rewriter.notifyMatchFailure(
+              binder.op, "Unsupported per-tensor quantization");
+
+        Value emptyList = rewriter.create<Torch::PrimListConstructOp>(
+            binder.getLoc(),
+            rewriter.getType<Torch::ListType>(
+                rewriter.getType<Torch::IntType>()),
+            ValueRange{});
+        auto extract = [&rewriter, &binder, &emptyList](Value v) {
+          auto vTy = cast<Torch::ValueTensorType>(v.getType());
+          if (!vTy.getSizes().empty()) {
+            vTy = rewriter.getType<Torch::ValueTensorType>(
+                ArrayRef<int64_t>({}), vTy.getOptionalDtype());
+            v = rewriter.create<Torch::AtenReshapeOp>(binder.getLoc(), vTy, v,
+                                                      emptyList);
+          }
+
+          Type extractTy = rewriter.getType<Torch::FloatType>();
+          if (isa<IntegerType>(vTy.getDtype()))
+            extractTy = rewriter.getType<Torch::IntType>();
+
+          return rewriter.create<Torch::AtenItemOp>(binder.getLoc(), extractTy,
+                                                    v);
+        };
+
+        aZp = extract(aZp);
+        bZp = extract(bZp);
+
+        Value cZp;
+        if (operands.size() == 8) {
+          cZp = operands[7];
+          if (!check(cZp))
+            return rewriter.notifyMatchFailure(
+                binder.op,
+                "Unsupported c_zero_point for per-tensor quantization");
+          cZp = extract(cZp);
+        } else {
+          cZp = rewriter.create<Torch::ConstantIntOp>(
+              loc, rewriter.getI64IntegerAttr(0));
+        }
+
+        aScale = extract(aScale);
+        bScale = extract(bScale);
+        cScale = extract(cScale);
+
+        auto makePerTensor = [&rewriter, &binder](Value v, Value scale,
+                                                  Value zp) -> Value {
+          auto ty = cast<Torch::ValueTensorType>(v.getType());
+          auto newTy = getQTorchTypeFromTorchIntType(ty);
+          return rewriter.create<Torch::Aten_MakePerTensorQuantizedTensorOp>(
+              binder.getLoc(), newTy, v, scale, zp);
+        };
+
+        a = makePerTensor(a, aScale, aZp);
+        b = makePerTensor(b, bScale, bZp);
+
+        auto aTy = dyn_cast<Torch::ValueTensorType>(a.getType());
+        if (!aTy || !aTy.hasSizes())
+          return rewriter.notifyMatchFailure(
+              binder.op, "Expected input argument `a` to have sizes");
+
+        aTy = rewriter.getType<Torch::ValueTensorType>(aTy.getOptionalSizes(),
+                                                       rewriter.getF32Type());
+        a = rewriter.create<Torch::AtenDequantizeSelfOp>(binder.getLoc(), aTy,
+                                                         a);
+
+        auto bTy = dyn_cast<Torch::ValueTensorType>(b.getType());
+        if (!bTy || !bTy.hasSizes())
+          return rewriter.notifyMatchFailure(
+              binder.op, "Expected input argument `b` to have sizes");
+
+        bTy = rewriter.getType<Torch::ValueTensorType>(bTy.getOptionalSizes(),
+                                                       rewriter.getF32Type());
+        b = rewriter.create<Torch::AtenDequantizeSelfOp>(binder.getLoc(), bTy,
+                                                         b);
+
+        auto cTy = rewriter.getType<Torch::ValueTensorType>(
+            resultType.getOptionalSizes(), rewriter.getF32Type());
+        Value alpha = rewriter.create<Torch::ConstantFloatOp>(
+            loc, rewriter.getF64FloatAttr(1.0));
+        Value c = rewriter.create<Torch::AtenAddTensorOp>(binder.getLoc(), cTy,
+                                                          a, b, alpha);
+
+        cTy = dyn_cast<Torch::ValueTensorType>(
+            getQTorchTypeFromTorchIntType(resultType));
+        Value dtyVal = rewriter.create<Torch::ConstantIntOp>(
+            binder.getLoc(), rewriter.getType<Torch::IntType>(),
+            rewriter.getIntegerAttr(
+                rewriter.getIntegerType(64),
+                static_cast<int64_t>(
+                    Torch::getScalarTypeForType(cTy.getDtype()))));
+        c = rewriter.create<Torch::AtenQuantizePerTensorOp>(
+            binder.getLoc(), cTy, c, cScale, cZp, dtyVal);
+        rewriter.replaceOpWithNewOp<Torch::AtenIntReprOp>(binder.op, resultType,
+                                                          c);
+        return success();
+      });
 }

test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir

Lines changed: 33 additions & 0 deletions
@@ -3679,3 +3679,36 @@ func.func @test_rotary_embedding(%arg0: !torch.vtensor<[1,3,2,6],f32>, %arg1: !t
   %4 = torch.operator "onnx.RotaryEmbedding"(%arg0, %arg1, %arg2, %arg3) : (!torch.vtensor<[1,3,2,6],f32>, !torch.vtensor<[1,2],si64>, !torch.vtensor<[4,3],f32>, !torch.vtensor<[4,3],f32>) -> !torch.vtensor<[1,3,2,6],f32>
   return %4 : !torch.vtensor<[1,3,2,6],f32>
 }
+
+// -----
+
+// CHECK-LABEL: @test_qlinearadd(
+// CHECK-SAME: %[[A:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: !torch.vtensor<[1,4096],ui8>,
+// CHECK-SAME: %[[A_SCALE:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: !torch.vtensor<[],f32>,
+// CHECK-SAME: %[[A_ZERO_POINT:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: !torch.vtensor<[],ui8>,
+// CHECK-SAME: %[[B:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: !torch.vtensor<[4096],ui8>,
+// CHECK-SAME: %[[B_SCALE:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: !torch.vtensor<[],f32>,
+// CHECK-SAME: %[[B_ZERO_POINT:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: !torch.vtensor<[],ui8>,
+// CHECK-SAME: %[[C_SCALE:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: !torch.vtensor<[],f32>,
+// CHECK-SAME: %[[C_ZERO_POINT:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: !torch.vtensor<[],ui8>) -> !torch.vtensor<[1,4096],ui8>
+func.func @test_qlinearadd(%arg0: !torch.vtensor<[1,4096],ui8>, %arg1: !torch.vtensor<[],f32>, %arg2: !torch.vtensor<[],ui8>, %arg3: !torch.vtensor<[4096],ui8>, %arg4: !torch.vtensor<[],f32>, %arg5: !torch.vtensor<[],ui8>, %arg6: !torch.vtensor<[],f32>, %arg7: !torch.vtensor<[],ui8>) -> !torch.vtensor<[1,4096],ui8> attributes {torch.onnx_meta.opset_version = 10 : si64} {
+  %0 = torch.operator "onnx.QLinearAdd"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7) : (!torch.vtensor<[1,4096],ui8>, !torch.vtensor<[],f32>, !torch.vtensor<[],ui8>, !torch.vtensor<[4096],ui8>, !torch.vtensor<[],f32>, !torch.vtensor<[],ui8>, !torch.vtensor<[],f32>, !torch.vtensor<[],ui8>) -> !torch.vtensor<[1,4096],ui8>
+  // CHECK-DAG: %[[EMPTY:.+]] = torch.prim.ListConstruct : () -> !torch.list<int>
+  // CHECK-DAG: %[[AZP:.+]] = torch.aten.item %[[A_ZERO_POINT]] : !torch.vtensor<[],ui8> -> !torch.int
+  // CHECK-DAG: %[[BZP:.+]] = torch.aten.item %[[B_ZERO_POINT]] : !torch.vtensor<[],ui8> -> !torch.int
+  // CHECK-DAG: %[[CZP:.+]] = torch.aten.item %[[C_ZERO_POINT]] : !torch.vtensor<[],ui8> -> !torch.int
+  // CHECK-DAG: %[[ASCALE:.+]] = torch.aten.item %[[A_SCALE]] : !torch.vtensor<[],f32> -> !torch.float
+  // CHECK-DAG: %[[BSCALE:.+]] = torch.aten.item %[[B_SCALE]] : !torch.vtensor<[],f32> -> !torch.float
+  // CHECK-DAG: %[[CSCALE:.+]] = torch.aten.item %[[C_SCALE]] : !torch.vtensor<[],f32> -> !torch.float
+  // CHECK-DAG: %[[A_QUANT:.+]] = torch.aten._make_per_tensor_quantized_tensor %[[A]], %[[ASCALE]], %[[AZP]] : !torch.vtensor<[1,4096],ui8>, !torch.float, !torch.int -> !torch.vtensor<[1,4096],!torch.quint8>
+  // CHECK-DAG: %[[B_QUANT:.+]] = torch.aten._make_per_tensor_quantized_tensor %[[B]], %[[BSCALE]], %[[BZP]] : !torch.vtensor<[4096],ui8>, !torch.float, !torch.int -> !torch.vtensor<[4096],!torch.quint8>
+  // CHECK: %[[A_F32:.+]] = torch.aten.dequantize.self %[[A_QUANT]] : !torch.vtensor<[1,4096],!torch.quint8> -> !torch.vtensor<[1,4096],f32>
+  // CHECK: %[[B_F32:.+]] = torch.aten.dequantize.self %[[B_QUANT]] : !torch.vtensor<[4096],!torch.quint8> -> !torch.vtensor<[4096],f32>
+  // CHECK: %[[ALPHA:.+]] = torch.constant.float 1.000000e+00
+  // CHECK: %[[ADD:.+]] = torch.aten.add.Tensor %[[A_F32]], %[[B_F32]], %[[ALPHA]] : !torch.vtensor<[1,4096],f32>, !torch.vtensor<[4096],f32>, !torch.float -> !torch.vtensor<[1,4096],f32>
+  // CHECK: %[[DTY:.+]] = torch.constant.int 13
+  // CHECK: %[[QO:.+]] = torch.aten.quantize_per_tensor %[[ADD]], %[[CSCALE]], %[[CZP]], %[[DTY]] : !torch.vtensor<[1,4096],f32>, !torch.float, !torch.int, !torch.int -> !torch.vtensor<[1,4096],!torch.quint8>
+  // CHECK: %[[OUT:.+]] = torch.aten.int_repr %[[QO]] : !torch.vtensor<[1,4096],!torch.quint8> -> !torch.vtensor<[1,4096],ui8>
+  // CHECK: return %[[OUT]]
+  return %0 : !torch.vtensor<[1,4096],ui8>
+}
