Commit 244f4b6

[Torch] Canonicalize pool ops with single int tuple params. (#4250)

Fixes #3885 by repeating the single int to match the expected number of spatial dims.

1 parent: 1cf2871
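
For reference, PyTorch pooling ops accept either a single int or a tuple for kernel_size, stride, padding, and dilation, with a single int applying to every spatial dimension; the canonicalization added here reproduces that rule on the Torch-dialect list operands. A small PyTorch-level illustration (not part of the commit; shapes and values are arbitrary):

import torch
import torch.nn as nn

# A single int is shorthand for one value per spatial dimension.
x = torch.randn(1, 3, 8, 8)
pool_int = nn.MaxPool2d(kernel_size=2, stride=1, padding=0, dilation=2)
pool_tuple = nn.MaxPool2d(kernel_size=(2, 2), stride=(1, 1), padding=(0, 0), dilation=(2, 2))

# Both modules compute the same result; the canonicalizer makes the
# corresponding torch-mlir ops identical as well by expanding [2] -> [2, 2], etc.
assert torch.equal(pool_int(x), pool_tuple(x))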

7 files changed: +304 −14 lines

include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td

Lines changed: 4 additions & 0 deletions

(Generated file: these additions correspond to the has_canonicalizer=True flags added in torch_ods_gen.py below.)

@@ -7744,6 +7744,7 @@ def Torch_AtenMaxPool2dOp : Torch_Op<"aten.max_pool2d", [
       printDefaultTorchOp(printer, *this, 6, 1);
     }
   }];
+  let hasCanonicalizer = 1;
 }
 
 def Torch_AtenMaxUnpool2dOp : Torch_Op<"aten.max_unpool2d", [
@@ -7857,6 +7858,7 @@ def Torch_AtenMaxPool3dOp : Torch_Op<"aten.max_pool3d", [
       printDefaultTorchOp(printer, *this, 6, 1);
     }
   }];
+  let hasCanonicalizer = 1;
 }
 
 def Torch_AtenMaxUnpool3dOp : Torch_Op<"aten.max_unpool3d", [
@@ -8001,6 +8003,7 @@ def Torch_AtenAvgPool2dOp : Torch_Op<"aten.avg_pool2d", [
       printDefaultTorchOp(printer, *this, 7, 1);
     }
   }];
+  let hasCanonicalizer = 1;
 }
 
 def Torch_AtenAvgPool2dBackwardOp : Torch_Op<"aten.avg_pool2d_backward", [
@@ -8060,6 +8063,7 @@ def Torch_AtenAvgPool3dOp : Torch_Op<"aten.avg_pool3d", [
       printDefaultTorchOp(printer, *this, 7, 1);
     }
   }];
+  let hasCanonicalizer = 1;
 }
 
 def Torch_AtenAvgPool3dBackwardOp : Torch_Op<"aten.avg_pool3d_backward", [

lib/Conversion/TorchToTosa/TorchToTosa.cpp

Lines changed: 0 additions & 6 deletions

(The 2-D special case removed here is no longer needed: the new Torch-level canonicalization already expands singleton parameters before the op reaches the TOSA lowering.)

@@ -6275,12 +6275,6 @@ void expandPoolParams(AtenOpT op, SmallVectorImpl<int64_t> &params,
   if constexpr (std::is_same<AtenOpT, AtenMaxPool1dOp>() ||
                 std::is_same<AtenOpT, AtenAvgPool1dOp>())
     params.push_back(val);
-
-  if constexpr (std::is_same<AtenOpT, AtenMaxPool2dOp>() ||
-                std::is_same<AtenOpT, AtenAvgPool2dOp>()) {
-    if (params.size() == 1)
-      params.push_back(params[0]);
-  }
 }
 
 // Checks the validity of pooling parameters and stores them in the respective

lib/Dialect/Torch/IR/TorchOps.cpp

Lines changed: 178 additions & 0 deletions

@@ -5720,6 +5720,184 @@ void Aten_AdaptiveAvgPool2dOp::getCanonicalizationPatterns(
   });
 }
 
+namespace {
+
+void expand(SmallVectorImpl<int64_t> &params, int numSpatialDims) {
+  if (params.size() == 1) {
+    for ([[maybe_unused]] int dim : llvm::seq<int>(0, numSpatialDims - 1)) {
+      params.push_back(params[0]);
+    }
+  }
+}
+
+template <typename AtenPoolOpT>
+LogicalResult expandPoolParams(AtenPoolOpT op, int numSpatialDims,
+                               mlir::PatternRewriter &rewriter,
+                               Value &kernelSizeList, Value &stridesList,
+                               Value &paddingList, Value &dilationsList) {
+
+  SmallVector<int64_t, 3> kernelSizeInts, strideInts, paddingInts, dilationInts;
+  if (!matchPattern(op.getKernelSize(),
+                    m_TorchListOfConstantInts(kernelSizeInts)))
+    return rewriter.notifyMatchFailure(
+        op, "Non-const kernel_size for pooling op unsupported");
+
+  if (!matchPattern(op.getPadding(), m_TorchListOfConstantInts(paddingInts)))
+    return rewriter.notifyMatchFailure(
+        op, "Non-const padding factor for pooling op unsupported");
+
+  if (!matchPattern(op.getStride(), m_TorchListOfConstantInts(strideInts)))
+    return rewriter.notifyMatchFailure(
+        op, "Non-const stride for pooling op unsupported");
+
+  if constexpr (std::is_same<AtenPoolOpT, AtenMaxPool2dOp>() ||
+                std::is_same<AtenPoolOpT, AtenMaxPool3dOp>()) {
+    if (!matchPattern(op.getDilation(),
+                      m_TorchListOfConstantInts(dilationInts)))
+      return rewriter.notifyMatchFailure(
+          op, "Non-const dilation for pooling op unsupported");
+
+    if (kernelSizeInts.size() != 1 && paddingInts.size() != 1 &&
+        strideInts.size() != 1 && dilationInts.size() != 1) {
+      return rewriter.notifyMatchFailure(
+          op,
+          "Expected one of kernel/stride/padding/dilation to be singleton.");
+    }
+
+    expand(dilationInts, numSpatialDims);
+
+  } else if (kernelSizeInts.size() != 1 && paddingInts.size() != 1 &&
+             strideInts.size() != 1) {
+    return rewriter.notifyMatchFailure(
+        op, "Expected one of kernel/stride/padding to be singleton.");
+  }
+
+  // expand singleton elements
+  expand(kernelSizeInts, numSpatialDims);
+  expand(paddingInts, numSpatialDims);
+  expand(strideInts, numSpatialDims);
+
+  Location loc = op.getLoc();
+
+  SmallVector<Value> cstKernel, cstPadding, cstStrides, cstDilations;
+  for (auto dim : llvm::seq<int>(0, kernelSizeInts.size())) {
+    cstKernel.push_back(Torch::ConstantIntOp::create(
+        rewriter, loc, rewriter.getI64IntegerAttr(kernelSizeInts[dim])));
+    cstPadding.push_back(Torch::ConstantIntOp::create(
+        rewriter, loc, rewriter.getI64IntegerAttr(paddingInts[dim])));
+    cstStrides.push_back(Torch::ConstantIntOp::create(
+        rewriter, loc, rewriter.getI64IntegerAttr(strideInts[dim])));
+  }
+
+  // set dilations separately as for AvgPool op it won't be set
+  for (auto dim : llvm::seq<int>(0, dilationInts.size())) {
+    cstDilations.push_back(Torch::ConstantIntOp::create(
+        rewriter, loc, rewriter.getI64IntegerAttr(dilationInts[dim])));
+  }
+
+  auto targetListType =
+      Torch::ListType::get(Torch::IntType::get(op->getContext()));
+  kernelSizeList = Torch::PrimListConstructOp::create(
+      rewriter, loc, targetListType, cstKernel);
+  paddingList = Torch::PrimListConstructOp::create(rewriter, loc,
+                                                   targetListType, cstPadding);
+  stridesList = Torch::PrimListConstructOp::create(rewriter, loc,
+                                                   targetListType, cstStrides);
+  dilationsList = Torch::PrimListConstructOp::create(
+      rewriter, loc, targetListType, cstDilations);
+
+  return success();
+}
+
+template <typename AvgPoolOpT>
+struct CanonicalizeAvgPoolWithSingleIntTuple
+    : public mlir::OpRewritePattern<AvgPoolOpT> {
+  CanonicalizeAvgPoolWithSingleIntTuple(mlir::MLIRContext *context)
+      : OpRewritePattern<AvgPoolOpT>(context, /*benefit=*/1) {}
+
+  LogicalResult
+  matchAndRewrite(AvgPoolOpT op,
+                  mlir::PatternRewriter &rewriter) const override {
+    Value kernel, stride, pad, dilations;
+
+    auto numSpatialDims = 2;
+    if constexpr (std::is_same<AvgPoolOpT, AtenAvgPool3dOp>())
+      numSpatialDims = 3;
+
+    // Attempt to expand params if necessary.
+    if (failed(expandPoolParams(op, numSpatialDims, rewriter, kernel, stride,
+                                pad, dilations)))
+      return rewriter.notifyMatchFailure(
+          op, "Failed to expand params for AvgPooling");
+
+    rewriter.replaceOpWithNewOp<AvgPoolOpT>(
+        op, op.getResult().getType(), op.getSelf(), kernel, stride, pad,
+        op.getCeilMode(), op.getCountIncludePad(), op.getDivisorOverride());
+    return success();
+  }
+};
+
+template <typename MaxPoolOpT>
+struct CanonicalizeMaxPoolWithSingleIntTuple
+    : public mlir::OpRewritePattern<MaxPoolOpT> {
+  CanonicalizeMaxPoolWithSingleIntTuple(mlir::MLIRContext *context)
+      : OpRewritePattern<MaxPoolOpT>(context, /*benefit=*/1) {}
+
+  LogicalResult
+  matchAndRewrite(MaxPoolOpT op,
+                  mlir::PatternRewriter &rewriter) const override {
+    Value kernel, stride, pad, dilations;
+
+    auto numSpatialDims = 2;
+    if constexpr (std::is_same<MaxPoolOpT, AtenMaxPool3dOp>())
+      numSpatialDims = 3;
+
+    // Attempt to expand params if necessary.
+    if (failed(expandPoolParams(op, numSpatialDims, rewriter, kernel, stride,
+                                pad, dilations)))
+      return rewriter.notifyMatchFailure(
+          op, "Failed to expand params for MaxPooling");
+
+    rewriter.replaceOpWithNewOp<MaxPoolOpT>(op, op.getResult().getType(),
+                                            op.getSelf(), kernel, stride, pad,
+                                            dilations, op.getCeilMode());
+    return success();
+  }
+};
+} // namespace
+
+//===----------------------------------------------------------------------===//
+// AtenAvgPool2dOp
+//===----------------------------------------------------------------------===//
+void AtenAvgPool2dOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
+                                                  MLIRContext *context) {
+  patterns.add<CanonicalizeAvgPoolWithSingleIntTuple<AtenAvgPool2dOp>>(context);
+}
+
+//===----------------------------------------------------------------------===//
+// AtenAvgPool3dOp
+//===----------------------------------------------------------------------===//
+void AtenAvgPool3dOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
+                                                  MLIRContext *context) {
+  patterns.add<CanonicalizeAvgPoolWithSingleIntTuple<AtenAvgPool3dOp>>(context);
+}
+
+//===----------------------------------------------------------------------===//
+// AtenMaxPool2dOp
+//===----------------------------------------------------------------------===//
+void AtenMaxPool2dOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
+                                                  MLIRContext *context) {
+  patterns.add<CanonicalizeMaxPoolWithSingleIntTuple<AtenMaxPool2dOp>>(context);
+}
+
+//===----------------------------------------------------------------------===//
+// AtenMaxPool3dOp
+//===----------------------------------------------------------------------===//
+void AtenMaxPool3dOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
+                                                  MLIRContext *context) {
+  patterns.add<CanonicalizeMaxPoolWithSingleIntTuple<AtenMaxPool3dOp>>(context);
+}
+
 //===----------------------------------------------------------------------===//
 // AtenLinalgCrossOp
 //===----------------------------------------------------------------------===//
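
For readers skimming the C++ above, the essential transformation is the singleton expansion performed by expand()/expandPoolParams(). A minimal Python sketch of that rule (illustration only; it is not part of the patch):

# Length-1 parameter lists are repeated to cover every spatial dimension;
# dilation is only consulted for the max-pool variants.
def expand(params, num_spatial_dims):
    return params * num_spatial_dims if len(params) == 1 else params

# e.g. aten.max_pool3d called with kernel_size=[2], stride=[1], padding=[0], dilation=[1]
num_spatial_dims = 3
kernel = expand([2], num_spatial_dims)    # [2, 2, 2]
stride = expand([1], num_spatial_dims)    # [1, 1, 1]
padding = expand([0], num_spatial_dims)   # [0, 0, 0]
dilation = expand([1], num_spatial_dims)  # [1, 1, 1]
assert kernel == [2, 2, 2] and dilation == [1, 1, 1]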

projects/pt1/e2e_testing/xfail_sets.py

Lines changed: 7 additions & 4 deletions

@@ -533,8 +533,6 @@
     "Aten_TrilinearModuleVaryingRanksUnorderedExpands_basic",
     "Aten_TrilinearModuleSumAllDims_basic",
     "Aten_TrilinearModuleSumdims_basic",
-    "AvgPool2dSingleIntTupleParamsIncludePadModule_basic",
-    "AvgPool2dSingleIntTupleParamsModule_basic",
     "SliceOutOfLowerBoundEndIndexModule_basic",
     "RollModule_basic",
     "AdaptiveAvgPool2dDynamicNoBatch_basic",
@@ -990,8 +988,6 @@
 }
 
 FX_IMPORTER_STABLEHLO_CRASHING_SET = {
-    "AvgPool2dSingleIntTupleParamsIncludePadModule_basic",
-    "AvgPool2dSingleIntTupleParamsModule_basic",
     "BatchNorm1DModule_basic",
     "BatchNorm2DModule_basic",
     "BatchNorm3DModule_basic",
@@ -2863,6 +2859,7 @@
     "AvgPool1dPadCeilPadNotIncluded_basic",
     "AvgPool2dDiffKernelsStridesPadCeilPadNotIncluded_basic",
     "AvgPool3dDiffKernelsStridesPadCeilPadNotIncluded_basic",
+    "AvgPool3dSingleIntTupleStrideModule_basic",
     "BatchMlpLayerModule_basic",
     "BincountMinlengthModule_basic",
     "BincountModule_basic",
@@ -3059,6 +3056,7 @@
     "MaxPool2dWithIndicesNonDefaultDilationModule_basic",
     "MaxPool2dWithIndicesNonDefaultParamsModule_basic",
     "MaxPool2dWithIndicesNonDefaultStrideModule_basic",
+    "MaxPool2dSingleIntTupleKernelModule_basic",
     "MaxPool3dCeilModeTrueModule_basic",
     "MaxPool3dLargeDatadModule_basic",
     "MaxPool3dModuleRandomSimple_basic",
@@ -3070,6 +3068,7 @@
     "MaxPool3dWithIndicesNonDefaultDilationModule_basic",
     "MaxPool3dWithIndicesNonDefaultParamsModule_basic",
     "MaxPool3dWithIndicesNonDefaultStrideModule_basic",
+    "MaxPool3dSingleIntTupleDilationModule_basic",
     "MaxUnpool3dModule_basic",
     "MaxUnpool3dModulePad0_basic",
     "MaxUnpool2dModule_basic",
@@ -3579,6 +3578,7 @@
     "AvgPool3dStaticModule_basic",
     "AvgPool3dCountIncludePadFalse_basic",
     "AvgPool3dCountIncludePadFalseWithoutPadding_basic",
+    "AvgPool3dSingleIntTupleStrideModule_basic",
     "Conv_Transpose1dModule_basic",
     "Conv_Transpose1dStaticModule_basic",
     "Conv_Transpose3dModule_basic",
@@ -3821,6 +3821,7 @@
     "MaxPool3dWithIndicesNonDefaultParamsModule_basic",
     "MaxPool3dWithIndicesNonDefaultStrideModule_basic",
     "MaxPool3dWithIndicesStaticModule_basic",
+    "MaxPool3dSingleIntTupleDilationModule_basic",
     "MeanDimEmptyDimModule_basic",
     "MlGroupNormManualModule_basic",
     "MlGroupNormModule_basic",
@@ -4270,6 +4271,7 @@
     "AvgPool2dIntModule_basic",
     "AvgPool2dStaticModule_basic",
     "AvgPool2dWithoutPadModule_basic",
+    "AvgPool3dSingleIntTupleStrideModule_basic",
     "BatchMlpLayerModule_basic",
     "BernoulliFloatModule_basic",
     "BernoulliModule_basic",
@@ -4684,6 +4686,7 @@
     "MaxPool3dWithIndicesNonDefaultParamsModule_basic",
     "MaxPool3dWithIndicesNonDefaultStrideModule_basic",
     "MaxPool3dWithIndicesStaticModule_basic",
+    "MaxPool3dSingleIntTupleDilationModule_basic",
     "MeanDimAllReduceKeepdimModule_basic",
     "MeanDimAllReduceModule_basic",
     "MeanDimDtypeModule_basic",

projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py

Lines changed: 12 additions & 4 deletions

(Passing has_canonicalizer=True to emit() is what produces the `let hasCanonicalizer = 1;` declarations in the generated GeneratedTorchOps.td above.)

@@ -657,7 +657,10 @@ def emit_with_mutating_variants(key, **kwargs):
     emit(
         "aten::max_pool1d_with_indices : (Tensor, int[], int[], int[], int[], bool) -> (Tensor, Tensor)"
     )
-    emit("aten::max_pool2d : (Tensor, int[], int[], int[], int[], bool) -> (Tensor)")
+    emit(
+        "aten::max_pool2d : (Tensor, int[], int[], int[], int[], bool) -> (Tensor)",
+        has_canonicalizer=True,
+    )
     emit("aten::max_unpool2d : (Tensor, Tensor, int[]) -> (Tensor)")
     emit(
         "aten::max_pool2d_with_indices : (Tensor, int[], int[], int[], int[], bool) -> (Tensor, Tensor)",
@@ -666,7 +669,10 @@ def emit_with_mutating_variants(key, **kwargs):
     emit(
         "aten::max_pool2d_with_indices_backward : (Tensor, Tensor, int[], int[], int[], int[], bool, Tensor) -> (Tensor)"
     )
-    emit("aten::max_pool3d : (Tensor, int[], int[], int[], int[], bool) -> (Tensor)")
+    emit(
+        "aten::max_pool3d : (Tensor, int[], int[], int[], int[], bool) -> (Tensor)",
+        has_canonicalizer=True,
+    )
     emit("aten::max_unpool3d : (Tensor, Tensor, int[], int[], int[]) -> (Tensor)")
     emit(
         "aten::max_pool3d_with_indices : (Tensor, int[], int[], int[], int[], bool) -> (Tensor, Tensor)",
@@ -677,13 +683,15 @@ def emit_with_mutating_variants(key, **kwargs):
     )
     emit("aten::avg_pool1d : (Tensor, int[], int[], int[], bool, bool) -> (Tensor)")
     emit(
-        "aten::avg_pool2d : (Tensor, int[], int[], int[], bool, bool, int?) -> (Tensor)"
+        "aten::avg_pool2d : (Tensor, int[], int[], int[], bool, bool, int?) -> (Tensor)",
+        has_canonicalizer=True,
     )
     emit(
         "aten::avg_pool2d_backward : (Tensor, Tensor, int[], int[], int[], bool, bool, int?) -> (Tensor)"
     )
     emit(
-        "aten::avg_pool3d : (Tensor, int[], int[], int[], bool, bool, int?) -> (Tensor)"
+        "aten::avg_pool3d : (Tensor, int[], int[], int[], bool, bool, int?) -> (Tensor)",
+        has_canonicalizer=True,
     )
     emit(
         "aten::avg_pool3d_backward : (Tensor, Tensor, int[], int[], int[], bool, bool, int?) -> (Tensor)"
