@@ -5601,6 +5601,184 @@ void Aten_AdaptiveAvgPool2dOp::getCanonicalizationPatterns(
5601
5601
});
5602
5602
}
5603
5603
5604
+ namespace {
5605
+
5606
+ void expand (SmallVectorImpl<int64_t > ¶ms, int numSpatialDims) {
5607
+ if (params.size () == 1 ) {
5608
+ for (auto _ : llvm::seq<int >(0 , numSpatialDims - 1 )) {
5609
+ params.push_back (params[0 ]);
5610
+ }
5611
+ }
5612
+ }
5613
+
5614
+ template <typename AtenPoolOpT>
5615
+ LogicalResult expandPoolParams (AtenPoolOpT op, int numSpatialDims,
5616
+ mlir::PatternRewriter &rewriter,
5617
+ Value &kernelSizeList, Value &stridesList,
5618
+ Value &paddingList, Value &dilationsList) {
5619
+
5620
+ SmallVector<int64_t , 3 > kernelSizeInts, strideInts, paddingInts, dilationInts;
5621
+ if (!matchPattern (op.getKernelSize (),
5622
+ m_TorchListOfConstantInts (kernelSizeInts)))
5623
+ return rewriter.notifyMatchFailure (
5624
+ op, " Non-const kernel_size for pooling op unsupported" );
5625
+
5626
+ if (!matchPattern (op.getPadding (), m_TorchListOfConstantInts (paddingInts)))
5627
+ return rewriter.notifyMatchFailure (
5628
+ op, " Non-const padding factor for pooling op unsupported" );
5629
+
5630
+ if (!matchPattern (op.getStride (), m_TorchListOfConstantInts (strideInts)))
5631
+ return rewriter.notifyMatchFailure (
5632
+ op, " Non-const stride for pooling op unsupported" );
5633
+
5634
+ if constexpr (std::is_same<AtenPoolOpT, AtenMaxPool2dOp>() ||
5635
+ std::is_same<AtenPoolOpT, AtenMaxPool3dOp>()) {
5636
+ if (!matchPattern (op.getDilation (),
5637
+ m_TorchListOfConstantInts (dilationInts)))
5638
+ return rewriter.notifyMatchFailure (
5639
+ op, " Non-const dilation for pooling op unsupported" );
5640
+
5641
+ if (kernelSizeInts.size () != 1 && paddingInts.size () != 1 &&
5642
+ strideInts.size () != 1 && dilationInts.size () != 1 ) {
5643
+ return rewriter.notifyMatchFailure (
5644
+ op,
5645
+ " Expected one of kernel/stride/padding/dilation to be singleton." );
5646
+ }
5647
+
5648
+ expand (dilationInts, numSpatialDims);
5649
+
5650
+ } else if (kernelSizeInts.size () != 1 && paddingInts.size () != 1 &&
5651
+ strideInts.size () != 1 ) {
5652
+ return rewriter.notifyMatchFailure (
5653
+ op, " Expected one of kernel/stride/padding to be singleton." );
5654
+ }
5655
+
5656
+ // expand singleton elements
5657
+ expand (kernelSizeInts, numSpatialDims);
5658
+ expand (paddingInts, numSpatialDims);
5659
+ expand (strideInts, numSpatialDims);
5660
+
5661
+ Location loc = op.getLoc ();
5662
+
5663
+ SmallVector<Value> cstKernel, cstPadding, cstStrides, cstDilations;
5664
+ for (auto dim : llvm::seq<int >(0 , kernelSizeInts.size ())) {
5665
+ cstKernel.push_back (rewriter.create <Torch::ConstantIntOp>(
5666
+ loc, rewriter.getI64IntegerAttr (kernelSizeInts[dim])));
5667
+ cstPadding.push_back (rewriter.create <Torch::ConstantIntOp>(
5668
+ loc, rewriter.getI64IntegerAttr (paddingInts[dim])));
5669
+ cstStrides.push_back (rewriter.create <Torch::ConstantIntOp>(
5670
+ loc, rewriter.getI64IntegerAttr (strideInts[dim])));
5671
+ }
5672
+
5673
+ // set dilations separately as for AvgPool op it won't be set
5674
+ for (auto dim : llvm::seq<int >(0 , dilationInts.size ())) {
5675
+ cstDilations.push_back (rewriter.create <Torch::ConstantIntOp>(
5676
+ loc, rewriter.getI64IntegerAttr (dilationInts[dim])));
5677
+ }
5678
+
5679
+ auto targetListType =
5680
+ Torch::ListType::get (Torch::IntType::get (op->getContext ()));
5681
+ kernelSizeList = rewriter.create <Torch::PrimListConstructOp>(
5682
+ loc, targetListType, cstKernel);
5683
+ paddingList = rewriter.create <Torch::PrimListConstructOp>(loc, targetListType,
5684
+ cstPadding);
5685
+ stridesList = rewriter.create <Torch::PrimListConstructOp>(loc, targetListType,
5686
+ cstStrides);
5687
+ dilationsList = rewriter.create <Torch::PrimListConstructOp>(
5688
+ loc, targetListType, cstDilations);
5689
+
5690
+ return success ();
5691
+ }
5692
+
5693
+ template <typename AvgPoolOpT>
5694
+ struct CanonicalizeAvgPoolWithSingleIntTuple
5695
+ : public mlir::OpRewritePattern<AvgPoolOpT> {
5696
+ CanonicalizeAvgPoolWithSingleIntTuple (mlir::MLIRContext *context)
5697
+ : OpRewritePattern<AvgPoolOpT>(context, /* benefit=*/ 1 ) {}
5698
+
5699
+ LogicalResult
5700
+ matchAndRewrite (AvgPoolOpT op,
5701
+ mlir::PatternRewriter &rewriter) const override {
5702
+ Value kernel, stride, pad, dilations;
5703
+
5704
+ auto numSpatialDims = 2 ;
5705
+ if constexpr (std::is_same<AvgPoolOpT, AtenAvgPool3dOp>())
5706
+ numSpatialDims = 3 ;
5707
+
5708
+ // Attempt to expand params if necessary.
5709
+ if (failed (expandPoolParams (op, numSpatialDims, rewriter, kernel, stride,
5710
+ pad, dilations)))
5711
+ return rewriter.notifyMatchFailure (op,
5712
+ " Failed to expand params for pooling" );
5713
+
5714
+ rewriter.replaceOpWithNewOp <AvgPoolOpT>(
5715
+ op, op.getResult ().getType (), op.getSelf (), kernel, stride, pad,
5716
+ op.getCeilMode (), op.getCountIncludePad (), op.getDivisorOverride ());
5717
+ return success ();
5718
+ }
5719
+ };
5720
+
5721
+ template <typename MaxPoolOpT>
5722
+ struct CanonicalizeMaxPoolWithSingleIntTuple
5723
+ : public mlir::OpRewritePattern<MaxPoolOpT> {
5724
+ CanonicalizeMaxPoolWithSingleIntTuple (mlir::MLIRContext *context)
5725
+ : OpRewritePattern<MaxPoolOpT>(context, /* benefit=*/ 1 ) {}
5726
+
5727
+ LogicalResult
5728
+ matchAndRewrite (MaxPoolOpT op,
5729
+ mlir::PatternRewriter &rewriter) const override {
5730
+ Value kernel, stride, pad, dilations;
5731
+
5732
+ auto numSpatialDims = 2 ;
5733
+ if constexpr (std::is_same<MaxPoolOpT, AtenMaxPool3dOp>())
5734
+ numSpatialDims = 3 ;
5735
+
5736
+ // Attempt to expand params if necessary.
5737
+ if (failed (expandPoolParams (op, numSpatialDims, rewriter, kernel, stride,
5738
+ pad, dilations)))
5739
+ return rewriter.notifyMatchFailure (op,
5740
+ " Failed to expand params for pooling" );
5741
+
5742
+ rewriter.replaceOpWithNewOp <MaxPoolOpT>(op, op.getResult ().getType (),
5743
+ op.getSelf (), kernel, stride, pad,
5744
+ dilations, op.getCeilMode ());
5745
+ return success ();
5746
+ }
5747
+ };
5748
+ } // namespace
5749
+
5750
+ // ===----------------------------------------------------------------------===//
5751
+ // AtenAvgPool2dOp
5752
+ // ===----------------------------------------------------------------------===//
5753
+ void AtenAvgPool2dOp::getCanonicalizationPatterns (RewritePatternSet &patterns,
5754
+ MLIRContext *context) {
5755
+ patterns.add <CanonicalizeAvgPoolWithSingleIntTuple<AtenAvgPool2dOp>>(context);
5756
+ }
5757
+
5758
+ // ===----------------------------------------------------------------------===//
5759
+ // AtenAvgPool3dOp
5760
+ // ===----------------------------------------------------------------------===//
5761
+ void AtenAvgPool3dOp::getCanonicalizationPatterns (RewritePatternSet &patterns,
5762
+ MLIRContext *context) {
5763
+ patterns.add <CanonicalizeAvgPoolWithSingleIntTuple<AtenAvgPool3dOp>>(context);
5764
+ }
5765
+
5766
+ // ===----------------------------------------------------------------------===//
5767
+ // AtenMaxPool2dOp
5768
+ // ===----------------------------------------------------------------------===//
5769
+ void AtenMaxPool2dOp::getCanonicalizationPatterns (RewritePatternSet &patterns,
5770
+ MLIRContext *context) {
5771
+ patterns.add <CanonicalizeMaxPoolWithSingleIntTuple<AtenMaxPool2dOp>>(context);
5772
+ }
5773
+
5774
+ // ===----------------------------------------------------------------------===//
5775
+ // AtenMaxPool3dOp
5776
+ // ===----------------------------------------------------------------------===//
5777
+ void AtenMaxPool3dOp::getCanonicalizationPatterns (RewritePatternSet &patterns,
5778
+ MLIRContext *context) {
5779
+ patterns.add <CanonicalizeMaxPoolWithSingleIntTuple<AtenMaxPool3dOp>>(context);
5780
+ }
5781
+
5604
5782
// ===----------------------------------------------------------------------===//
5605
5783
// AtenLinalgCrossOp
5606
5784
// ===----------------------------------------------------------------------===//
0 commit comments