Commit b5aa39f

[ONNX] Add support for Onnx.QLinearGlobalAveragePool op (#4120)
This commit adds the Onnx->Torch lowering for Onnx.[QLinearGlobalAveragePool](https://github.com/microsoft/onnxruntime/blob/main/docs/ContribOperators.md#commicrosoftqlinearglobalaveragepool) op.
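For context, the lowering realizes the usual QLinear pattern: dequantize `x` with (`x_scale`, `x_zero_point`), take the mean over the spatial dimensions, and requantize the result with (`y_scale`, `y_zero_point`). Below is a minimal host-side C++ sketch of that reference computation for an NCHW uint8 input (channels_last = 0, as in the new test); the function name, flat-vector layout, and round-to-nearest rounding are illustrative assumptions, not part of this patch.

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <vector>

// Hypothetical reference helper (not in the patch): global average pool over
// H x W of a per-tensor-quantized uint8 NCHW tensor, i.e.
// quantize(global_avg_pool(dequantize(x))).
std::vector<uint8_t> qlinearGlobalAvgPoolRef(const std::vector<uint8_t> &x,
                                             int64_t n, int64_t c, int64_t h,
                                             int64_t w, float xScale,
                                             uint8_t xZp, float yScale,
                                             uint8_t yZp) {
  std::vector<uint8_t> y(n * c); // output shape [N, C, 1, 1]
  for (int64_t ni = 0; ni < n; ++ni) {
    for (int64_t ci = 0; ci < c; ++ci) {
      // Dequantize each spatial element and accumulate it in float.
      double sum = 0.0;
      for (int64_t i = 0; i < h * w; ++i) {
        uint8_t q = x[(ni * c + ci) * h * w + i];
        sum += (static_cast<int>(q) - xZp) * static_cast<double>(xScale);
      }
      double mean = sum / static_cast<double>(h * w);
      // Requantize the pooled mean into the output's uint8 domain.
      int yq = static_cast<int>(std::lround(mean / yScale)) + yZp;
      y[ni * c + ci] = static_cast<uint8_t>(std::clamp(yq, 0, 255));
    }
  }
  return y;
}

In the lowered IR this corresponds to torch.aten.dequantize.self, an avg_pool op whose kernel spans the full spatial extent, torch.aten.quantize_per_tensor, and torch.aten.int_repr, as exercised by the new test below.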
1 parent 9e5531e commit b5aa39f

File tree: 2 files changed, +156 -0 lines changed

lib/Conversion/TorchOnnxToTorch/ComMicrosoftDomain.cpp

Lines changed: 122 additions & 0 deletions
@@ -748,4 +748,126 @@ void mlir::torch::onnx_c::populateComMicrosoftDomain(
             result);
         return success();
       });
+  patterns.onOp(
+      "QLinearGlobalAveragePool", 1,
+      [](OpBinder binder, ConversionPatternRewriter &rewriter) {
+        Location loc = binder.getLoc();
+        Torch::ValueTensorType resultType;
+        llvm::SmallVector<Value> operands;
+        int64_t channelsLast;
+        if (binder.tensorOperands(operands, 5) ||
+            binder.tensorResultType(resultType) ||
+            binder.s64IntegerAttr(channelsLast, "channels_last"))
+          return failure();
+
+        Value x = operands[0];
+        Value xScale, xZp, yScale, yZp;
+
+        if (failed(extractPerTensorQuantizationArguments(
+                rewriter, loc, /*scale=*/operands[1],
+                /*zero_point=*/operands[2], xScale, xZp)))
+          return rewriter.notifyMatchFailure(
+              binder.op, "Incompatible arguments for per-tensor quantization");
+
+        if (failed(extractPerTensorQuantizationArguments(
+                rewriter, loc, /*scale=*/operands[3],
+                /*zero_point=*/operands[4], yScale, yZp)))
+          return rewriter.notifyMatchFailure(
+              binder.op, "Incompatible arguments for per-tensor quantization");
+
+        auto xTy = dyn_cast<Torch::ValueTensorType>(x.getType());
+        if (!xTy || !xTy.hasSizes())
+          return rewriter.notifyMatchFailure(
+              binder.op, "Expected input argument `x` to have sizes");
+        ArrayRef<int64_t> inputShape = xTy.getSizes();
+
+        xTy = getQTorchTypeFromTorchIntType(xTy);
+        x = rewriter.create<Torch::Aten_MakePerTensorQuantizedTensorOp>(
+            loc, xTy, x, xScale, xZp);
+        xTy = rewriter.getType<Torch::ValueTensorType>(inputShape,
+                                                       rewriter.getF32Type());
+        // Dequantizing the input tensor `x`.
+        x = rewriter.create<Torch::AtenDequantizeSelfOp>(loc, xTy, x);
+
+        if (!resultType || !resultType.hasSizes()) {
+          return rewriter.notifyMatchFailure(
+              binder.op, "Expected result type having sizes");
+        }
+        ArrayRef<int64_t> resultShape = resultType.getSizes();
+
+        // Computing the AvgPool result.
+        SmallVector<Value> cstKernel, cstPadding, cstStrides;
+        Value cstZero = rewriter.create<Torch::ConstantIntOp>(
+            loc, rewriter.getI64IntegerAttr(0));
+        Value cstOne = rewriter.create<Torch::ConstantIntOp>(
+            loc, rewriter.getI64IntegerAttr(1));
+        unsigned inputRank = inputShape.size();
+        for (unsigned i = 2; i < inputRank; i++) {
+          if (inputShape[i] == Torch::kUnknownSize) {
+            Value dim = rewriter.create<Torch::ConstantIntOp>(
+                loc, rewriter.getI64IntegerAttr(i));
+            Value inputDimSize =
+                rewriter.create<Torch::AtenSizeIntOp>(loc, x, dim);
+            cstKernel.push_back(inputDimSize);
+          } else {
+            int64_t kernelSize = inputShape[i] - resultShape[i] + 1;
+            cstKernel.push_back(rewriter.create<Torch::ConstantIntOp>(
+                loc, rewriter.getI64IntegerAttr(kernelSize)));
+          }
+          cstPadding.push_back(cstZero);
+          cstStrides.push_back(cstOne);
+        }
+        Value kernelSizeList = rewriter.create<Torch::PrimListConstructOp>(
+            loc,
+            Torch::ListType::get(Torch::IntType::get(binder.op->getContext())),
+            cstKernel);
+        Value paddingList = rewriter.create<Torch::PrimListConstructOp>(
+            loc,
+            Torch::ListType::get(Torch::IntType::get(binder.op->getContext())),
+            cstPadding);
+        Value stridesList = rewriter.create<Torch::PrimListConstructOp>(
+            loc,
+            Torch::ListType::get(Torch::IntType::get(binder.op->getContext())),
+            cstStrides);
+        Value cstFalse = rewriter.create<Torch::ConstantBoolOp>(loc, false);
+        Value cstCeilMode = cstFalse;
+        Value cstCountIncludePad = cstFalse;
+        Value cstNone = rewriter.create<Torch::ConstantNoneOp>(loc);
+
+        auto yTy = rewriter.getType<Torch::ValueTensorType>(
+            resultShape, rewriter.getF32Type());
+        Value avgpool;
+        if (inputRank == 3) {
+          avgpool = rewriter.create<Torch::AtenAvgPool1dOp>(
+              loc, yTy, x, kernelSizeList, stridesList, paddingList,
+              cstCeilMode, cstCountIncludePad);
+        } else if (inputRank == 4) {
+          avgpool = rewriter.create<Torch::AtenAvgPool2dOp>(
+              loc, yTy, x, kernelSizeList, stridesList, paddingList,
+              cstCeilMode, cstCountIncludePad,
+              /*divisor_override=*/cstNone);
+        } else if (inputRank == 5) {
+          avgpool = rewriter.create<Torch::AtenAvgPool3dOp>(
+              loc, yTy, x, kernelSizeList, stridesList, paddingList,
+              cstCeilMode, cstCountIncludePad,
+              /*divisor_override=*/cstNone);
+        } else {
+          return failure();
+        }
+
+        // Quantizing the result of AvgPool op.
+        yTy = dyn_cast<Torch::ValueTensorType>(
+            getQTorchTypeFromTorchIntType(resultType));
+        Value dtyVal = rewriter.create<Torch::ConstantIntOp>(
+            binder.getLoc(), rewriter.getType<Torch::IntType>(),
+            rewriter.getIntegerAttr(
+                rewriter.getIntegerType(64),
+                static_cast<int64_t>(
+                    Torch::getScalarTypeForType(yTy.getDtype()))));
+        avgpool = rewriter.create<Torch::AtenQuantizePerTensorOp>(
+            loc, yTy, avgpool, yScale, yZp, dtyVal);
+        rewriter.replaceOpWithNewOp<Torch::AtenIntReprOp>(binder.op, resultType,
+                                                          avgpool);
+        return success();
+      });
 }

test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir

Lines changed: 34 additions & 0 deletions
@@ -3763,3 +3763,37 @@ func.func @test_qlinearconcat(%arg0: !torch.vtensor<[],f32>, %arg1: !torch.vtens
 // CHECK: return %[[OUT]]
   return %0 : !torch.vtensor<[?,?,?,?],ui8>
 }
+
+// -----
+
+// CHECK-LABEL: @test_qlinearglobalavgpool(
+// CHECK-SAME: %[[X:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: !torch.vtensor<[1,1000,13,13],ui8>,
+// CHECK-SAME: %[[X_SCALE:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: !torch.vtensor<[],f32>,
+// CHECK-SAME: %[[X_ZERO_POINT:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: !torch.vtensor<[],ui8>,
+// CHECK-SAME: %[[Y_SCALE:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: !torch.vtensor<[],f32>,
+// CHECK-SAME: %[[Y_ZERO_POINT:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: !torch.vtensor<[],ui8>) -> !torch.vtensor<[1,1000,1,1],ui8>
+func.func @test_qlinearglobalavgpool(%arg0: !torch.vtensor<[1,1000,13,13],ui8>, %arg1: !torch.vtensor<[],f32>, %arg2: !torch.vtensor<[],ui8>, %arg3: !torch.vtensor<[],f32>, %arg4: !torch.vtensor<[],ui8>) -> !torch.vtensor<[1,1000,1,1],ui8> attributes {torch.onnx_meta.ir_version = 5 : si64, torch.onnx_meta.opset_version = 10 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
+  %0 = torch.operator "onnx.QLinearGlobalAveragePool"(%arg0, %arg1, %arg2, %arg3, %arg4) {torch.onnx.channels_last = 0 : si64} : (!torch.vtensor<[1,1000,13,13],ui8>, !torch.vtensor<[],f32>, !torch.vtensor<[],ui8>, !torch.vtensor<[],f32>, !torch.vtensor<[],ui8>) -> !torch.vtensor<[1,1000,1,1],ui8>
+  // CHECK-DAG: %[[EMPTY:.+]] = torch.prim.ListConstruct : () -> !torch.list<int>
+  // CHECK-DAG: %[[XSCALE:.+]] = torch.aten.item %[[X_SCALE]] : !torch.vtensor<[],f32> -> !torch.float
+  // CHECK-DAG: %[[YSCALE:.+]] = torch.aten.item %[[Y_SCALE]] : !torch.vtensor<[],f32> -> !torch.float
+  // CHECK-DAG: %[[XZP:.+]] = torch.aten.item %[[X_ZERO_POINT]] : !torch.vtensor<[],ui8> -> !torch.int
+  // CHECK-DAG: %[[YZP:.+]] = torch.aten.item %[[Y_ZERO_POINT]] : !torch.vtensor<[],ui8> -> !torch.int
+  // CHECK-DAG: %[[X_QUANT:.+]] = torch.aten._make_per_tensor_quantized_tensor %[[X]], %[[XSCALE]], %[[XZP]] : !torch.vtensor<[1,1000,13,13],ui8>, !torch.float, !torch.int -> !torch.vtensor<[1,1000,13,13],!torch.quint8>
+  // CHECK: %[[X_F32:.+]] = torch.aten.dequantize.self %[[X_QUANT]] : !torch.vtensor<[1,1000,13,13],!torch.quint8> -> !torch.vtensor<[1,1000,13,13],f32>
+  // CHECK: %[[C0:.*]] = torch.constant.int 0
+  // CHECK: %[[C1:.*]] = torch.constant.int 1
+  // CHECK: %[[C13:.*]] = torch.constant.int 13
+  // CHECK: %[[C13_0:.*]] = torch.constant.int 13
+  // CHECK: %[[KERNELSIZE:.*]] = torch.prim.ListConstruct %[[C13]], %[[C13_0]] : (!torch.int, !torch.int) -> !torch.list<int>
+  // CHECK: %[[PADDING:.*]] = torch.prim.ListConstruct %[[C0]], %[[C0]] : (!torch.int, !torch.int) -> !torch.list<int>
+  // CHECK: %[[STRIDE:.*]] = torch.prim.ListConstruct %[[C1]], %[[C1]] : (!torch.int, !torch.int) -> !torch.list<int>
+  // CHECK: %[[FALSE:.*]] = torch.constant.bool false
+  // CHECK: %[[NONE:.*]] = torch.constant.none
+  // CHECK: %[[AVGPOOL:.*]] = torch.aten.avg_pool2d %[[X_F32]], %[[KERNELSIZE]], %[[STRIDE]], %[[PADDING]], %[[FALSE]], %[[FALSE]], %[[NONE]] : !torch.vtensor<[1,1000,13,13],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[1,1000,1,1],f32>
+  // CHECK: %[[DTY:.+]] = torch.constant.int 13
+  // CHECK: %[[QO:.+]] = torch.aten.quantize_per_tensor %[[AVGPOOL]], %[[YSCALE]], %[[YZP]], %[[DTY]] : !torch.vtensor<[1,1000,1,1],f32>, !torch.float, !torch.int, !torch.int -> !torch.vtensor<[1,1000,1,1],!torch.quint8>
+  // CHECK: %[[OUT:.+]] = torch.aten.int_repr %[[QO]] : !torch.vtensor<[1,1000,1,1],!torch.quint8> -> !torch.vtensor<[1,1000,1,1],ui8>
+  // CHECK: return %[[OUT]]
+  return %0 : !torch.vtensor<[1,1000,1,1],ui8>
+}
