@@ -5720,6 +5720,184 @@ void Aten_AdaptiveAvgPool2dOp::getCanonicalizationPatterns(
57205720 });
57215721}
57225722
5723+ namespace {
5724+
5725+ void expand (SmallVectorImpl<int64_t > ¶ms, int numSpatialDims) {
5726+ if (params.size () == 1 ) {
5727+ for ([[maybe_unused]] int dim : llvm::seq<int >(0 , numSpatialDims - 1 )) {
5728+ params.push_back (params[0 ]);
5729+ }
5730+ }
5731+ }
5732+
5733+ template <typename AtenPoolOpT>
5734+ LogicalResult expandPoolParams (AtenPoolOpT op, int numSpatialDims,
5735+ mlir::PatternRewriter &rewriter,
5736+ Value &kernelSizeList, Value &stridesList,
5737+ Value &paddingList, Value &dilationsList) {
5738+
5739+ SmallVector<int64_t , 3 > kernelSizeInts, strideInts, paddingInts, dilationInts;
5740+ if (!matchPattern (op.getKernelSize (),
5741+ m_TorchListOfConstantInts (kernelSizeInts)))
5742+ return rewriter.notifyMatchFailure (
5743+ op, " Non-const kernel_size for pooling op unsupported" );
5744+
5745+ if (!matchPattern (op.getPadding (), m_TorchListOfConstantInts (paddingInts)))
5746+ return rewriter.notifyMatchFailure (
5747+ op, " Non-const padding factor for pooling op unsupported" );
5748+
5749+ if (!matchPattern (op.getStride (), m_TorchListOfConstantInts (strideInts)))
5750+ return rewriter.notifyMatchFailure (
5751+ op, " Non-const stride for pooling op unsupported" );
5752+
5753+ if constexpr (std::is_same<AtenPoolOpT, AtenMaxPool2dOp>() ||
5754+ std::is_same<AtenPoolOpT, AtenMaxPool3dOp>()) {
5755+ if (!matchPattern (op.getDilation (),
5756+ m_TorchListOfConstantInts (dilationInts)))
5757+ return rewriter.notifyMatchFailure (
5758+ op, " Non-const dilation for pooling op unsupported" );
5759+
5760+ if (kernelSizeInts.size () != 1 && paddingInts.size () != 1 &&
5761+ strideInts.size () != 1 && dilationInts.size () != 1 ) {
5762+ return rewriter.notifyMatchFailure (
5763+ op,
5764+ " Expected one of kernel/stride/padding/dilation to be singleton." );
5765+ }
5766+
5767+ expand (dilationInts, numSpatialDims);
5768+
5769+ } else if (kernelSizeInts.size () != 1 && paddingInts.size () != 1 &&
5770+ strideInts.size () != 1 ) {
5771+ return rewriter.notifyMatchFailure (
5772+ op, " Expected one of kernel/stride/padding to be singleton." );
5773+ }
5774+
5775+ // expand singleton elements
5776+ expand (kernelSizeInts, numSpatialDims);
5777+ expand (paddingInts, numSpatialDims);
5778+ expand (strideInts, numSpatialDims);
5779+
5780+ Location loc = op.getLoc ();
5781+
5782+ SmallVector<Value> cstKernel, cstPadding, cstStrides, cstDilations;
5783+ for (auto dim : llvm::seq<int >(0 , kernelSizeInts.size ())) {
5784+ cstKernel.push_back (Torch::ConstantIntOp::create (
5785+ rewriter, loc, rewriter.getI64IntegerAttr (kernelSizeInts[dim])));
5786+ cstPadding.push_back (Torch::ConstantIntOp::create (
5787+ rewriter, loc, rewriter.getI64IntegerAttr (paddingInts[dim])));
5788+ cstStrides.push_back (Torch::ConstantIntOp::create (
5789+ rewriter, loc, rewriter.getI64IntegerAttr (strideInts[dim])));
5790+ }
5791+
5792+ // set dilations separately as for AvgPool op it won't be set
5793+ for (auto dim : llvm::seq<int >(0 , dilationInts.size ())) {
5794+ cstDilations.push_back (Torch::ConstantIntOp::create (
5795+ rewriter, loc, rewriter.getI64IntegerAttr (dilationInts[dim])));
5796+ }
5797+
5798+ auto targetListType =
5799+ Torch::ListType::get (Torch::IntType::get (op->getContext ()));
5800+ kernelSizeList = Torch::PrimListConstructOp::create (
5801+ rewriter, loc, targetListType, cstKernel);
5802+ paddingList = Torch::PrimListConstructOp::create (rewriter, loc,
5803+ targetListType, cstPadding);
5804+ stridesList = Torch::PrimListConstructOp::create (rewriter, loc,
5805+ targetListType, cstStrides);
5806+ dilationsList = Torch::PrimListConstructOp::create (
5807+ rewriter, loc, targetListType, cstDilations);
5808+
5809+ return success ();
5810+ }
5811+
5812+ template <typename AvgPoolOpT>
5813+ struct CanonicalizeAvgPoolWithSingleIntTuple
5814+ : public mlir::OpRewritePattern<AvgPoolOpT> {
5815+ CanonicalizeAvgPoolWithSingleIntTuple (mlir::MLIRContext *context)
5816+ : OpRewritePattern<AvgPoolOpT>(context, /* benefit=*/ 1 ) {}
5817+
5818+ LogicalResult
5819+ matchAndRewrite (AvgPoolOpT op,
5820+ mlir::PatternRewriter &rewriter) const override {
5821+ Value kernel, stride, pad, dilations;
5822+
5823+ auto numSpatialDims = 2 ;
5824+ if constexpr (std::is_same<AvgPoolOpT, AtenAvgPool3dOp>())
5825+ numSpatialDims = 3 ;
5826+
5827+ // Attempt to expand params if necessary.
5828+ if (failed (expandPoolParams (op, numSpatialDims, rewriter, kernel, stride,
5829+ pad, dilations)))
5830+ return rewriter.notifyMatchFailure (
5831+ op, " Failed to expand params for AvgPooling" );
5832+
5833+ rewriter.replaceOpWithNewOp <AvgPoolOpT>(
5834+ op, op.getResult ().getType (), op.getSelf (), kernel, stride, pad,
5835+ op.getCeilMode (), op.getCountIncludePad (), op.getDivisorOverride ());
5836+ return success ();
5837+ }
5838+ };
5839+
5840+ template <typename MaxPoolOpT>
5841+ struct CanonicalizeMaxPoolWithSingleIntTuple
5842+ : public mlir::OpRewritePattern<MaxPoolOpT> {
5843+ CanonicalizeMaxPoolWithSingleIntTuple (mlir::MLIRContext *context)
5844+ : OpRewritePattern<MaxPoolOpT>(context, /* benefit=*/ 1 ) {}
5845+
5846+ LogicalResult
5847+ matchAndRewrite (MaxPoolOpT op,
5848+ mlir::PatternRewriter &rewriter) const override {
5849+ Value kernel, stride, pad, dilations;
5850+
5851+ auto numSpatialDims = 2 ;
5852+ if constexpr (std::is_same<MaxPoolOpT, AtenMaxPool3dOp>())
5853+ numSpatialDims = 3 ;
5854+
5855+ // Attempt to expand params if necessary.
5856+ if (failed (expandPoolParams (op, numSpatialDims, rewriter, kernel, stride,
5857+ pad, dilations)))
5858+ return rewriter.notifyMatchFailure (
5859+ op, " Failed to expand params for MaxPooling" );
5860+
5861+ rewriter.replaceOpWithNewOp <MaxPoolOpT>(op, op.getResult ().getType (),
5862+ op.getSelf (), kernel, stride, pad,
5863+ dilations, op.getCeilMode ());
5864+ return success ();
5865+ }
5866+ };
5867+ } // namespace
5868+
5869+ // ===----------------------------------------------------------------------===//
5870+ // AtenAvgPool2dOp
5871+ // ===----------------------------------------------------------------------===//
5872+ void AtenAvgPool2dOp::getCanonicalizationPatterns (RewritePatternSet &patterns,
5873+ MLIRContext *context) {
5874+ patterns.add <CanonicalizeAvgPoolWithSingleIntTuple<AtenAvgPool2dOp>>(context);
5875+ }
5876+
5877+ // ===----------------------------------------------------------------------===//
5878+ // AtenAvgPool3dOp
5879+ // ===----------------------------------------------------------------------===//
5880+ void AtenAvgPool3dOp::getCanonicalizationPatterns (RewritePatternSet &patterns,
5881+ MLIRContext *context) {
5882+ patterns.add <CanonicalizeAvgPoolWithSingleIntTuple<AtenAvgPool3dOp>>(context);
5883+ }
5884+
5885+ // ===----------------------------------------------------------------------===//
5886+ // AtenMaxPool2dOp
5887+ // ===----------------------------------------------------------------------===//
5888+ void AtenMaxPool2dOp::getCanonicalizationPatterns (RewritePatternSet &patterns,
5889+ MLIRContext *context) {
5890+ patterns.add <CanonicalizeMaxPoolWithSingleIntTuple<AtenMaxPool2dOp>>(context);
5891+ }
5892+
5893+ // ===----------------------------------------------------------------------===//
5894+ // AtenMaxPool3dOp
5895+ // ===----------------------------------------------------------------------===//
5896+ void AtenMaxPool3dOp::getCanonicalizationPatterns (RewritePatternSet &patterns,
5897+ MLIRContext *context) {
5898+ patterns.add <CanonicalizeMaxPoolWithSingleIntTuple<AtenMaxPool3dOp>>(context);
5899+ }
5900+
//===----------------------------------------------------------------------===//
// AtenLinalgCrossOp
//===----------------------------------------------------------------------===//
0 commit comments