@@ -748,4 +748,126 @@ void mlir::torch::onnx_c::populateComMicrosoftDomain(
748748 result);
749749 return success ();
750750 });
751+ patterns.onOp (
752+ " QLinearGlobalAveragePool" , 1 ,
753+ [](OpBinder binder, ConversionPatternRewriter &rewriter) {
754+ Location loc = binder.getLoc ();
755+ Torch::ValueTensorType resultType;
756+ llvm::SmallVector<Value> operands;
757+ int64_t channelsLast;
758+ if (binder.tensorOperands (operands, 5 ) ||
759+ binder.tensorResultType (resultType) ||
760+ binder.s64IntegerAttr (channelsLast, " channels_last" ))
761+ return failure ();
762+
763+ Value x = operands[0 ];
764+ Value xScale, xZp, yScale, yZp;
765+
766+ if (failed (extractPerTensorQuantizationArguments (
767+ rewriter, loc, /* scale=*/ operands[1 ],
768+ /* zero_point=*/ operands[2 ], xScale, xZp)))
769+ return rewriter.notifyMatchFailure (
770+ binder.op , " Incompatible arguments for per-tensor quantization" );
771+
772+ if (failed (extractPerTensorQuantizationArguments (
773+ rewriter, loc, /* scale=*/ operands[3 ],
774+ /* zero_point=*/ operands[4 ], yScale, yZp)))
775+ return rewriter.notifyMatchFailure (
776+ binder.op , " Incompatible arguments for per-tensor quantization" );
777+
778+ auto xTy = dyn_cast<Torch::ValueTensorType>(x.getType ());
779+ if (!xTy || !xTy.hasSizes ())
780+ return rewriter.notifyMatchFailure (
781+ binder.op , " Expected input argument `x` to have sizes" );
782+ ArrayRef<int64_t > inputShape = xTy.getSizes ();
783+
784+ xTy = getQTorchTypeFromTorchIntType (xTy);
785+ x = rewriter.create <Torch::Aten_MakePerTensorQuantizedTensorOp>(
786+ loc, xTy, x, xScale, xZp);
787+ xTy = rewriter.getType <Torch::ValueTensorType>(inputShape,
788+ rewriter.getF32Type ());
789+ // Dequantizing the input tensor `x`.
790+ x = rewriter.create <Torch::AtenDequantizeSelfOp>(loc, xTy, x);
791+
792+ if (!resultType || !resultType.hasSizes ()) {
793+ return rewriter.notifyMatchFailure (
794+ binder.op , " Expected result type having sizes" );
795+ }
796+ ArrayRef<int64_t > resultShape = resultType.getSizes ();
797+
798+ // Computing the AvgPool result.
799+ SmallVector<Value> cstKernel, cstPadding, cstStrides;
800+ Value cstZero = rewriter.create <Torch::ConstantIntOp>(
801+ loc, rewriter.getI64IntegerAttr (0 ));
802+ Value cstOne = rewriter.create <Torch::ConstantIntOp>(
803+ loc, rewriter.getI64IntegerAttr (1 ));
804+ unsigned inputRank = inputShape.size ();
805+ for (unsigned i = 2 ; i < inputRank; i++) {
806+ if (inputShape[i] == Torch::kUnknownSize ) {
807+ Value dim = rewriter.create <Torch::ConstantIntOp>(
808+ loc, rewriter.getI64IntegerAttr (i));
809+ Value inputDimSize =
810+ rewriter.create <Torch::AtenSizeIntOp>(loc, x, dim);
811+ cstKernel.push_back (inputDimSize);
812+ } else {
813+ int64_t kernelSize = inputShape[i] - resultShape[i] + 1 ;
814+ cstKernel.push_back (rewriter.create <Torch::ConstantIntOp>(
815+ loc, rewriter.getI64IntegerAttr (kernelSize)));
816+ }
817+ cstPadding.push_back (cstZero);
818+ cstStrides.push_back (cstOne);
819+ }
820+ Value kernelSizeList = rewriter.create <Torch::PrimListConstructOp>(
821+ loc,
822+ Torch::ListType::get (Torch::IntType::get (binder.op ->getContext ())),
823+ cstKernel);
824+ Value paddingList = rewriter.create <Torch::PrimListConstructOp>(
825+ loc,
826+ Torch::ListType::get (Torch::IntType::get (binder.op ->getContext ())),
827+ cstPadding);
828+ Value stridesList = rewriter.create <Torch::PrimListConstructOp>(
829+ loc,
830+ Torch::ListType::get (Torch::IntType::get (binder.op ->getContext ())),
831+ cstStrides);
832+ Value cstFalse = rewriter.create <Torch::ConstantBoolOp>(loc, false );
833+ Value cstCeilMode = cstFalse;
834+ Value cstCountIncludePad = cstFalse;
835+ Value cstNone = rewriter.create <Torch::ConstantNoneOp>(loc);
836+
837+ auto yTy = rewriter.getType <Torch::ValueTensorType>(
838+ resultShape, rewriter.getF32Type ());
839+ Value avgpool;
840+ if (inputRank == 3 ) {
841+ avgpool = rewriter.create <Torch::AtenAvgPool1dOp>(
842+ loc, yTy, x, kernelSizeList, stridesList, paddingList,
843+ cstCeilMode, cstCountIncludePad);
844+ } else if (inputRank == 4 ) {
845+ avgpool = rewriter.create <Torch::AtenAvgPool2dOp>(
846+ loc, yTy, x, kernelSizeList, stridesList, paddingList,
847+ cstCeilMode, cstCountIncludePad,
848+ /* divisor_override=*/ cstNone);
849+ } else if (inputRank == 5 ) {
850+ avgpool = rewriter.create <Torch::AtenAvgPool3dOp>(
851+ loc, yTy, x, kernelSizeList, stridesList, paddingList,
852+ cstCeilMode, cstCountIncludePad,
853+ /* divisor_override=*/ cstNone);
854+ } else {
855+ return failure ();
856+ }
857+
858+ // Quantizing the result of AvgPool op.
859+ yTy = dyn_cast<Torch::ValueTensorType>(
860+ getQTorchTypeFromTorchIntType (resultType));
861+ Value dtyVal = rewriter.create <Torch::ConstantIntOp>(
862+ binder.getLoc (), rewriter.getType <Torch::IntType>(),
863+ rewriter.getIntegerAttr (
864+ rewriter.getIntegerType (64 ),
865+ static_cast <int64_t >(
866+ Torch::getScalarTypeForType (yTy.getDtype ()))));
867+ avgpool = rewriter.create <Torch::AtenQuantizePerTensorOp>(
868+ loc, yTy, avgpool, yScale, yZp, dtyVal);
869+ rewriter.replaceOpWithNewOp <Torch::AtenIntReprOp>(binder.op , resultType,
870+ avgpool);
871+ return success ();
872+ });
751873}
0 commit comments