Skip to content

Commit 5e5a74a

Browse files
[ONNX] Add support for Onnx.QLinearLeakyRelu op (#4115)
This commit adds the Onnx->Torch lowering for [Onnx.QLinearLeakyRelu](https://github.com/microsoft/onnxruntime/blob/main/docs/ContribOperators.md#com.microsoft.QLinearLeakyRelu) op. Signed-off-by: Vivek Khandelwal <[email protected]>
1 parent cfb213e commit 5e5a74a

File tree

4 files changed

+140
-0
lines changed

4 files changed

+140
-0
lines changed

include/torch-mlir/Conversion/TorchOnnxToTorch/Utils.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,12 @@ LogicalResult createTorchPermuteOp(OpBinder binder,
119119
SmallVector<int64_t> permuteDims,
120120
Value &permuted);
121121

122+
/// This utility checks the compatibility for scale and zero_point value and
123+
/// extracts the scalar value from it used for per-tensor quantization.
124+
LogicalResult extractPerTensorQuantizationArguments(
125+
ConversionPatternRewriter &rewriter, Location loc, Value inScale,
126+
Value inZeroPoint, Value &outScale, Value &outZeroPoint);
127+
122128
} // namespace mlir::torch::onnx_c
123129

124130
#endif // TORCHMLIR_CONVERSION_TORCHONNXTOTORCH_UTILS_H

lib/Conversion/TorchOnnxToTorch/ComMicrosoftDomain.cpp

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -578,4 +578,71 @@ void mlir::torch::onnx_c::populateComMicrosoftDomain(
578578
c);
579579
return success();
580580
});
581+
patterns.onOp(
    "QLinearLeakyRelu", 1,
    [](OpBinder binder, ConversionPatternRewriter &rewriter) {
      // Lowers com.microsoft.QLinearLeakyRelu to Torch ops: rebuild the
      // quantized input, dequantize to f32, apply aten.leaky_relu, then
      // re-quantize with the output scale/zero-point and strip the
      // quantized wrapper with aten.int_repr.
      Location loc = binder.getLoc();
      Torch::ValueTensorType resultType;
      llvm::SmallVector<Value> operands;
      float alpha;
      if (binder.tensorOperandsList(operands) ||
          binder.tensorResultType(resultType) ||
          binder.f32FloatAttr(alpha, "alpha"))
        return failure();

      // Operands are: x, x_scale, x_zero_point, y_scale, y_zero_point.
      // Guard the indexing below instead of crashing on a malformed op.
      if (operands.size() != 5)
        return rewriter.notifyMatchFailure(
            binder.op, "expected 5 operands for QLinearLeakyRelu");

      Value x = operands[0];

      // These are pure out-parameters of the extraction utility; do not
      // pre-initialize them from `operands` (the previous values were
      // dead stores, immediately overwritten below).
      Value xScale, xZp, yScale, yZp;
      if (failed(extractPerTensorQuantizationArguments(
              rewriter, loc, /*inScale=*/operands[1],
              /*inZeroPoint=*/operands[2], xScale, xZp)))
        return rewriter.notifyMatchFailure(
            binder.op, "Incompatible arguments for per-tensor quantization");

      if (failed(extractPerTensorQuantizationArguments(
              rewriter, loc, /*inScale=*/operands[3],
              /*inZeroPoint=*/operands[4], yScale, yZp)))
        return rewriter.notifyMatchFailure(
            binder.op, "Incompatible arguments for per-tensor quantization");

      auto xTy = dyn_cast<Torch::ValueTensorType>(x.getType());
      if (!xTy || !xTy.hasSizes())
        return rewriter.notifyMatchFailure(
            binder.op, "Expected input argument `x` to have sizes");

      // Wrap the raw integer tensor as a quantized tensor, then
      // dequantize it to f32 so the activation runs in floating point.
      xTy = getQTorchTypeFromTorchIntType(xTy);
      x = rewriter.create<Torch::Aten_MakePerTensorQuantizedTensorOp>(
          loc, xTy, x, xScale, xZp);
      xTy = rewriter.getType<Torch::ValueTensorType>(xTy.getSizes(),
                                                     rewriter.getF32Type());
      x = rewriter.create<Torch::AtenDequantizeSelfOp>(loc, xTy, x);

      // Compute the LeakyRelu result in f32.
      Value constAlpha = rewriter.create<Torch::ConstantFloatOp>(
          loc, rewriter.getType<Torch::FloatType>(),
          rewriter.getF64FloatAttr((double)alpha));
      auto yTy = rewriter.getType<Torch::ValueTensorType>(
          resultType.getOptionalSizes(), rewriter.getF32Type());
      Value y =
          rewriter.create<Torch::AtenLeakyReluOp>(loc, yTy, x, constAlpha);

      // Quantize the result back to the (quantized counterpart of the)
      // declared result type. Use `loc` consistently for all ops created
      // by this pattern.
      yTy = dyn_cast<Torch::ValueTensorType>(
          getQTorchTypeFromTorchIntType(resultType));
      Value dtyVal = rewriter.create<Torch::ConstantIntOp>(
          loc, rewriter.getType<Torch::IntType>(),
          rewriter.getIntegerAttr(
              rewriter.getIntegerType(64),
              static_cast<int64_t>(
                  Torch::getScalarTypeForType(yTy.getDtype()))));
      y = rewriter.create<Torch::AtenQuantizePerTensorOp>(loc, yTy, y, yScale,
                                                          yZp, dtyVal);
      rewriter.replaceOpWithNewOp<Torch::AtenIntReprOp>(binder.op, resultType,
                                                        y);
      return success();
    });
581648
}

lib/Conversion/TorchOnnxToTorch/Utils.cpp

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -142,3 +142,43 @@ Value mlir::torch::onnx_c::createActivationByName(ImplicitLocOpBuilder &b,
142142
return b.create<Torch::AtenReluOp>(input.getType(), input);
143143
llvm_unreachable("Unsupported activation function");
144144
}
145+
146+
LogicalResult mlir::torch::onnx_c::extractPerTensorQuantizationArguments(
    ConversionPatternRewriter &rewriter, Location loc, Value inScale,
    Value inZeroPoint, Value &outScale, Value &outZeroPoint) {

  // Per-tensor quantization requires scale/zero_point tensors that hold a
  // single element: rank-0, or every dimension equal to 1.
  auto holdsSingleElement = [](Value v) {
    auto tensorTy = cast<Torch::ValueTensorType>(v.getType());
    for (int64_t dim : tensorTy.getSizes()) {
      if (dim != 1)
        return false;
    }
    return true;
  };

  if (!holdsSingleElement(inScale) || !holdsSingleElement(inZeroPoint))
    return failure();

  // Empty shape list shared by both reshapes below; used to collapse a
  // non-rank-0 single-element tensor down to rank-0 before aten.item.
  Value emptyShape = rewriter.create<Torch::PrimListConstructOp>(
      loc,
      rewriter.getType<Torch::ListType>(rewriter.getType<Torch::IntType>()),
      ValueRange{});

  auto itemize = [&](Value v) -> Value {
    auto tensorTy = cast<Torch::ValueTensorType>(v.getType());
    if (!tensorTy.getSizes().empty()) {
      tensorTy = rewriter.getType<Torch::ValueTensorType>(
          ArrayRef<int64_t>({}), tensorTy.getOptionalDtype());
      v = rewriter.create<Torch::AtenReshapeOp>(loc, tensorTy, v, emptyShape);
    }

    // aten.item yields !torch.int for integer dtypes, !torch.float
    // otherwise.
    Type scalarTy = isa<IntegerType>(tensorTy.getDtype())
                        ? (Type)rewriter.getType<Torch::IntType>()
                        : (Type)rewriter.getType<Torch::FloatType>();
    return rewriter.create<Torch::AtenItemOp>(loc, scalarTy, v);
  };

  outScale = itemize(inScale);
  outZeroPoint = itemize(inZeroPoint);

  return success();
}

test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3712,3 +3712,30 @@ func.func @test_qlinearadd(%arg0: !torch.vtensor<[1,4096],ui8>, %arg1: !torch.vt
37123712
// CHECK: return %[[OUT]]
37133713
return %0 : !torch.vtensor<[1,4096],ui8>
37143714
}
3715+
3716+
// -----
3717+
3718+
// Verifies the Onnx->Torch lowering of com.microsoft QLinearLeakyRelu:
// dequantize x -> aten.leaky_relu in f32 -> quantize_per_tensor -> int_repr.
// CHECK-LABEL: @test_qlinearleakyrelu(
// CHECK-SAME: %[[X:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: !torch.vtensor<[?,32,?,?],ui8>,
// CHECK-SAME: %[[X_SCALE:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: !torch.vtensor<[],f32>,
// CHECK-SAME: %[[X_ZERO_POINT:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: !torch.vtensor<[],ui8>,
// CHECK-SAME: %[[Y_SCALE:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: !torch.vtensor<[],f32>,
// CHECK-SAME: %[[Y_ZERO_POINT:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: !torch.vtensor<[],ui8>) -> !torch.vtensor<[?,32,?,?],ui8>
func.func @test_qlinearleakyrelu(%arg0: !torch.vtensor<[?,32,?,?],ui8>, %arg1: !torch.vtensor<[],f32>, %arg2: !torch.vtensor<[],ui8>, %arg3: !torch.vtensor<[],f32>, %arg4: !torch.vtensor<[],ui8>) -> !torch.vtensor<[?,32,?,?],ui8> attributes {torch.onnx_meta.ir_version = 5 : si64, torch.onnx_meta.opset_version = 10 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
  %0 = torch.operator "onnx.QLinearLeakyRelu"(%arg0, %arg1, %arg2, %arg3, %arg4) {torch.onnx.alpha = 1.000000e-01 : f32} : (!torch.vtensor<[?,32,?,?],ui8>, !torch.vtensor<[],f32>, !torch.vtensor<[],ui8>, !torch.vtensor<[],f32>, !torch.vtensor<[],ui8>) -> !torch.vtensor<[?,32,?,?],ui8>
  // The scalar quantization arguments are itemized first (CHECK-DAG: the
  // utility emits one empty list per scale/zero-point pair).
  // CHECK-DAG: %[[EMPTY:.+]] = torch.prim.ListConstruct : () -> !torch.list<int>
  // CHECK-DAG: %[[XSCALE:.+]] = torch.aten.item %[[X_SCALE]] : !torch.vtensor<[],f32> -> !torch.float
  // CHECK-DAG: %[[XZP:.+]] = torch.aten.item %[[X_ZERO_POINT]] : !torch.vtensor<[],ui8> -> !torch.int
  // CHECK-DAG: %[[EMPTY_0:.+]] = torch.prim.ListConstruct : () -> !torch.list<int>
  // CHECK-DAG: %[[YSCALE:.+]] = torch.aten.item %[[Y_SCALE]] : !torch.vtensor<[],f32> -> !torch.float
  // CHECK-DAG: %[[YZP:.+]] = torch.aten.item %[[Y_ZERO_POINT]] : !torch.vtensor<[],ui8> -> !torch.int
  // CHECK-DAG: %[[X_QUANT:.+]] = torch.aten._make_per_tensor_quantized_tensor %[[X]], %[[XSCALE]], %[[XZP]] : !torch.vtensor<[?,32,?,?],ui8>, !torch.float, !torch.int -> !torch.vtensor<[?,32,?,?],!torch.quint8>
  // CHECK: %[[X_F32:.+]] = torch.aten.dequantize.self %[[X_QUANT]] : !torch.vtensor<[?,32,?,?],!torch.quint8> -> !torch.vtensor<[?,32,?,?],f32>
  // CHECK: %[[ALPHA:.+]] = torch.constant.float 0.10000000149011612
  // CHECK: %[[LEAKYRELU:.+]] = torch.aten.leaky_relu %[[X_F32]], %[[ALPHA]] : !torch.vtensor<[?,32,?,?],f32>, !torch.float -> !torch.vtensor<[?,32,?,?],f32>
  // Torch scalar type 13 corresponds to the quint8 result dtype.
  // CHECK: %[[DTY:.+]] = torch.constant.int 13
  // CHECK: %[[QO:.+]] = torch.aten.quantize_per_tensor %[[LEAKYRELU]], %[[YSCALE]], %[[YZP]], %[[DTY]] : !torch.vtensor<[?,32,?,?],f32>, !torch.float, !torch.int, !torch.int -> !torch.vtensor<[?,32,?,?],!torch.quint8>
  // CHECK: %[[OUT:.+]] = torch.aten.int_repr %[[QO]] : !torch.vtensor<[?,32,?,?],!torch.quint8> -> !torch.vtensor<[?,32,?,?],ui8>
  // CHECK: return %[[OUT]]
  return %0 : !torch.vtensor<[?,32,?,?],ui8>
}

0 commit comments

Comments
 (0)