@@ -5617,6 +5617,184 @@ void Aten_AdaptiveAvgPool2dOp::getCanonicalizationPatterns(
5617
5617
});
5618
5618
}
5619
5619
5620
+ namespace {
5621
+
5622
+ void expand (SmallVectorImpl<int64_t > ¶ms, int numSpatialDims) {
5623
+ if (params.size () == 1 ) {
5624
+ for (auto _ : llvm::seq<int >(0 , numSpatialDims - 1 )) {
5625
+ params.push_back (params[0 ]);
5626
+ }
5627
+ }
5628
+ }
5629
+
5630
+ template <typename AtenPoolOpT>
5631
+ LogicalResult expandPoolParams (AtenPoolOpT op, int numSpatialDims,
5632
+ mlir::PatternRewriter &rewriter,
5633
+ Value &kernelSizeList, Value &stridesList,
5634
+ Value &paddingList, Value &dilationsList) {
5635
+
5636
+ SmallVector<int64_t , 3 > kernelSizeInts, strideInts, paddingInts, dilationInts;
5637
+ if (!matchPattern (op.getKernelSize (),
5638
+ m_TorchListOfConstantInts (kernelSizeInts)))
5639
+ return rewriter.notifyMatchFailure (
5640
+ op, " Non-const kernel_size for pooling op unsupported" );
5641
+
5642
+ if (!matchPattern (op.getPadding (), m_TorchListOfConstantInts (paddingInts)))
5643
+ return rewriter.notifyMatchFailure (
5644
+ op, " Non-const padding factor for pooling op unsupported" );
5645
+
5646
+ if (!matchPattern (op.getStride (), m_TorchListOfConstantInts (strideInts)))
5647
+ return rewriter.notifyMatchFailure (
5648
+ op, " Non-const stride for pooling op unsupported" );
5649
+
5650
+ if constexpr (std::is_same<AtenPoolOpT, AtenMaxPool2dOp>() ||
5651
+ std::is_same<AtenPoolOpT, AtenMaxPool3dOp>()) {
5652
+ if (!matchPattern (op.getDilation (),
5653
+ m_TorchListOfConstantInts (dilationInts)))
5654
+ return rewriter.notifyMatchFailure (
5655
+ op, " Non-const dilation for pooling op unsupported" );
5656
+
5657
+ if (kernelSizeInts.size () != 1 && paddingInts.size () != 1 &&
5658
+ strideInts.size () != 1 && dilationInts.size () != 1 ) {
5659
+ return rewriter.notifyMatchFailure (
5660
+ op,
5661
+ " Expected one of kernel/stride/padding/dilation to be singleton." );
5662
+ }
5663
+
5664
+ expand (dilationInts, numSpatialDims);
5665
+
5666
+ } else if (kernelSizeInts.size () != 1 && paddingInts.size () != 1 &&
5667
+ strideInts.size () != 1 ) {
5668
+ return rewriter.notifyMatchFailure (
5669
+ op, " Expected one of kernel/stride/padding to be singleton." );
5670
+ }
5671
+
5672
+ // expand singleton elements
5673
+ expand (kernelSizeInts, numSpatialDims);
5674
+ expand (paddingInts, numSpatialDims);
5675
+ expand (strideInts, numSpatialDims);
5676
+
5677
+ Location loc = op.getLoc ();
5678
+
5679
+ SmallVector<Value> cstKernel, cstPadding, cstStrides, cstDilations;
5680
+ for (auto dim : llvm::seq<int >(0 , kernelSizeInts.size ())) {
5681
+ cstKernel.push_back (rewriter.create <Torch::ConstantIntOp>(
5682
+ loc, rewriter.getI64IntegerAttr (kernelSizeInts[dim])));
5683
+ cstPadding.push_back (rewriter.create <Torch::ConstantIntOp>(
5684
+ loc, rewriter.getI64IntegerAttr (paddingInts[dim])));
5685
+ cstStrides.push_back (rewriter.create <Torch::ConstantIntOp>(
5686
+ loc, rewriter.getI64IntegerAttr (strideInts[dim])));
5687
+ }
5688
+
5689
+ // set dilations separately as for AvgPool op it won't be set
5690
+ for (auto dim : llvm::seq<int >(0 , dilationInts.size ())) {
5691
+ cstDilations.push_back (rewriter.create <Torch::ConstantIntOp>(
5692
+ loc, rewriter.getI64IntegerAttr (dilationInts[dim])));
5693
+ }
5694
+
5695
+ auto targetListType =
5696
+ Torch::ListType::get (Torch::IntType::get (op->getContext ()));
5697
+ kernelSizeList = rewriter.create <Torch::PrimListConstructOp>(
5698
+ loc, targetListType, cstKernel);
5699
+ paddingList = rewriter.create <Torch::PrimListConstructOp>(loc, targetListType,
5700
+ cstPadding);
5701
+ stridesList = rewriter.create <Torch::PrimListConstructOp>(loc, targetListType,
5702
+ cstStrides);
5703
+ dilationsList = rewriter.create <Torch::PrimListConstructOp>(
5704
+ loc, targetListType, cstDilations);
5705
+
5706
+ return success ();
5707
+ }
5708
+
5709
+ template <typename AvgPoolOpT>
5710
+ struct CanonicalizeAvgPoolWithSingleIntTuple
5711
+ : public mlir::OpRewritePattern<AvgPoolOpT> {
5712
+ CanonicalizeAvgPoolWithSingleIntTuple (mlir::MLIRContext *context)
5713
+ : OpRewritePattern<AvgPoolOpT>(context, /* benefit=*/ 1 ) {}
5714
+
5715
+ LogicalResult
5716
+ matchAndRewrite (AvgPoolOpT op,
5717
+ mlir::PatternRewriter &rewriter) const override {
5718
+ Value kernel, stride, pad, dilations;
5719
+
5720
+ auto numSpatialDims = 2 ;
5721
+ if constexpr (std::is_same<AvgPoolOpT, AtenAvgPool3dOp>())
5722
+ numSpatialDims = 3 ;
5723
+
5724
+ // Attempt to expand params if necessary.
5725
+ if (failed (expandPoolParams (op, numSpatialDims, rewriter, kernel, stride,
5726
+ pad, dilations)))
5727
+ return rewriter.notifyMatchFailure (op,
5728
+ " Failed to expand params for pooling" );
5729
+
5730
+ rewriter.replaceOpWithNewOp <AvgPoolOpT>(
5731
+ op, op.getResult ().getType (), op.getSelf (), kernel, stride, pad,
5732
+ op.getCeilMode (), op.getCountIncludePad (), op.getDivisorOverride ());
5733
+ return success ();
5734
+ }
5735
+ };
5736
+
5737
+ template <typename MaxPoolOpT>
5738
+ struct CanonicalizeMaxPoolWithSingleIntTuple
5739
+ : public mlir::OpRewritePattern<MaxPoolOpT> {
5740
+ CanonicalizeMaxPoolWithSingleIntTuple (mlir::MLIRContext *context)
5741
+ : OpRewritePattern<MaxPoolOpT>(context, /* benefit=*/ 1 ) {}
5742
+
5743
+ LogicalResult
5744
+ matchAndRewrite (MaxPoolOpT op,
5745
+ mlir::PatternRewriter &rewriter) const override {
5746
+ Value kernel, stride, pad, dilations;
5747
+
5748
+ auto numSpatialDims = 2 ;
5749
+ if constexpr (std::is_same<MaxPoolOpT, AtenMaxPool3dOp>())
5750
+ numSpatialDims = 3 ;
5751
+
5752
+ // Attempt to expand params if necessary.
5753
+ if (failed (expandPoolParams (op, numSpatialDims, rewriter, kernel, stride,
5754
+ pad, dilations)))
5755
+ return rewriter.notifyMatchFailure (op,
5756
+ " Failed to expand params for pooling" );
5757
+
5758
+ rewriter.replaceOpWithNewOp <MaxPoolOpT>(op, op.getResult ().getType (),
5759
+ op.getSelf (), kernel, stride, pad,
5760
+ dilations, op.getCeilMode ());
5761
+ return success ();
5762
+ }
5763
+ };
5764
+ } // namespace
5765
+
5766
+ // ===----------------------------------------------------------------------===//
5767
+ // AtenAvgPool2dOp
5768
+ // ===----------------------------------------------------------------------===//
5769
+ void AtenAvgPool2dOp::getCanonicalizationPatterns (RewritePatternSet &patterns,
5770
+ MLIRContext *context) {
5771
+ patterns.add <CanonicalizeAvgPoolWithSingleIntTuple<AtenAvgPool2dOp>>(context);
5772
+ }
5773
+
5774
+ // ===----------------------------------------------------------------------===//
5775
+ // AtenAvgPool3dOp
5776
+ // ===----------------------------------------------------------------------===//
5777
+ void AtenAvgPool3dOp::getCanonicalizationPatterns (RewritePatternSet &patterns,
5778
+ MLIRContext *context) {
5779
+ patterns.add <CanonicalizeAvgPoolWithSingleIntTuple<AtenAvgPool3dOp>>(context);
5780
+ }
5781
+
5782
+ // ===----------------------------------------------------------------------===//
5783
+ // AtenMaxPool2dOp
5784
+ // ===----------------------------------------------------------------------===//
5785
+ void AtenMaxPool2dOp::getCanonicalizationPatterns (RewritePatternSet &patterns,
5786
+ MLIRContext *context) {
5787
+ patterns.add <CanonicalizeMaxPoolWithSingleIntTuple<AtenMaxPool2dOp>>(context);
5788
+ }
5789
+
5790
+ // ===----------------------------------------------------------------------===//
5791
+ // AtenMaxPool3dOp
5792
+ // ===----------------------------------------------------------------------===//
5793
+ void AtenMaxPool3dOp::getCanonicalizationPatterns (RewritePatternSet &patterns,
5794
+ MLIRContext *context) {
5795
+ patterns.add <CanonicalizeMaxPoolWithSingleIntTuple<AtenMaxPool3dOp>>(context);
5796
+ }
5797
+
5620
5798
// ===----------------------------------------------------------------------===//
5621
5799
// AtenLinalgCrossOp
5622
5800
// ===----------------------------------------------------------------------===//
0 commit comments