Skip to content

Commit 39bf1a2

Browse files
committed
[Torch] Canonicalize pool ops with single int tuple params.
1 parent e65d38e commit 39bf1a2

File tree

5 files changed

+266
-10
lines changed

5 files changed

+266
-10
lines changed

include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7744,6 +7744,7 @@ def Torch_AtenMaxPool2dOp : Torch_Op<"aten.max_pool2d", [
77447744
printDefaultTorchOp(printer, *this, 6, 1);
77457745
}
77467746
}];
7747+
let hasCanonicalizer = 1;
77477748
}
77487749

77497750
def Torch_AtenMaxUnpool2dOp : Torch_Op<"aten.max_unpool2d", [
@@ -7857,6 +7858,7 @@ def Torch_AtenMaxPool3dOp : Torch_Op<"aten.max_pool3d", [
78577858
printDefaultTorchOp(printer, *this, 6, 1);
78587859
}
78597860
}];
7861+
let hasCanonicalizer = 1;
78607862
}
78617863

78627864
def Torch_AtenMaxUnpool3dOp : Torch_Op<"aten.max_unpool3d", [
@@ -8001,6 +8003,7 @@ def Torch_AtenAvgPool2dOp : Torch_Op<"aten.avg_pool2d", [
80018003
printDefaultTorchOp(printer, *this, 7, 1);
80028004
}
80038005
}];
8006+
let hasCanonicalizer = 1;
80048007
}
80058008

80068009
def Torch_AtenAvgPool2dBackwardOp : Torch_Op<"aten.avg_pool2d_backward", [
@@ -8060,6 +8063,7 @@ def Torch_AtenAvgPool3dOp : Torch_Op<"aten.avg_pool3d", [
80608063
printDefaultTorchOp(printer, *this, 7, 1);
80618064
}
80628065
}];
8066+
let hasCanonicalizer = 1;
80638067
}
80648068

80658069
def Torch_AtenAvgPool3dBackwardOp : Torch_Op<"aten.avg_pool3d_backward", [

lib/Conversion/TorchToTosa/TorchToTosa.cpp

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -6060,12 +6060,6 @@ void expandPoolParams(AtenOpT op, SmallVectorImpl<int64_t> &params,
60606060
if constexpr (std::is_same<AtenOpT, AtenMaxPool1dOp>() ||
60616061
std::is_same<AtenOpT, AtenAvgPool1dOp>())
60626062
params.push_back(val);
6063-
6064-
if constexpr (std::is_same<AtenOpT, AtenMaxPool2dOp>() ||
6065-
std::is_same<AtenOpT, AtenAvgPool2dOp>()) {
6066-
if (params.size() == 1)
6067-
params.push_back(params[0]);
6068-
}
60696063
}
60706064

60716065
// Checks the validity of pooling parameters and stores them in the respective

lib/Dialect/Torch/IR/TorchOps.cpp

Lines changed: 178 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5617,6 +5617,184 @@ void Aten_AdaptiveAvgPool2dOp::getCanonicalizationPatterns(
56175617
});
56185618
}
56195619

5620+
namespace {
5621+
5622+
void expand(SmallVectorImpl<int64_t> &params, int numSpatialDims) {
5623+
if (params.size() == 1) {
5624+
for (auto _ : llvm::seq<int>(0, numSpatialDims - 1)) {
5625+
params.push_back(params[0]);
5626+
}
5627+
}
5628+
}
5629+
5630+
template <typename AtenPoolOpT>
5631+
LogicalResult expandPoolParams(AtenPoolOpT op, int numSpatialDims,
5632+
mlir::PatternRewriter &rewriter,
5633+
Value &kernelSizeList, Value &stridesList,
5634+
Value &paddingList, Value &dilationsList) {
5635+
5636+
SmallVector<int64_t, 3> kernelSizeInts, strideInts, paddingInts, dilationInts;
5637+
if (!matchPattern(op.getKernelSize(),
5638+
m_TorchListOfConstantInts(kernelSizeInts)))
5639+
return rewriter.notifyMatchFailure(
5640+
op, "Non-const kernel_size for pooling op unsupported");
5641+
5642+
if (!matchPattern(op.getPadding(), m_TorchListOfConstantInts(paddingInts)))
5643+
return rewriter.notifyMatchFailure(
5644+
op, "Non-const padding factor for pooling op unsupported");
5645+
5646+
if (!matchPattern(op.getStride(), m_TorchListOfConstantInts(strideInts)))
5647+
return rewriter.notifyMatchFailure(
5648+
op, "Non-const stride for pooling op unsupported");
5649+
5650+
if constexpr (std::is_same<AtenPoolOpT, AtenMaxPool2dOp>() ||
5651+
std::is_same<AtenPoolOpT, AtenMaxPool3dOp>()) {
5652+
if (!matchPattern(op.getDilation(),
5653+
m_TorchListOfConstantInts(dilationInts)))
5654+
return rewriter.notifyMatchFailure(
5655+
op, "Non-const dilation for pooling op unsupported");
5656+
5657+
if (kernelSizeInts.size() != 1 && paddingInts.size() != 1 &&
5658+
strideInts.size() != 1 && dilationInts.size() != 1) {
5659+
return rewriter.notifyMatchFailure(
5660+
op,
5661+
"Expected one of kernel/stride/padding/dilation to be singleton.");
5662+
}
5663+
5664+
expand(dilationInts, numSpatialDims);
5665+
5666+
} else if (kernelSizeInts.size() != 1 && paddingInts.size() != 1 &&
5667+
strideInts.size() != 1) {
5668+
return rewriter.notifyMatchFailure(
5669+
op, "Expected one of kernel/stride/padding to be singleton.");
5670+
}
5671+
5672+
// expand singleton elements
5673+
expand(kernelSizeInts, numSpatialDims);
5674+
expand(paddingInts, numSpatialDims);
5675+
expand(strideInts, numSpatialDims);
5676+
5677+
Location loc = op.getLoc();
5678+
5679+
SmallVector<Value> cstKernel, cstPadding, cstStrides, cstDilations;
5680+
for (auto dim : llvm::seq<int>(0, kernelSizeInts.size())) {
5681+
cstKernel.push_back(rewriter.create<Torch::ConstantIntOp>(
5682+
loc, rewriter.getI64IntegerAttr(kernelSizeInts[dim])));
5683+
cstPadding.push_back(rewriter.create<Torch::ConstantIntOp>(
5684+
loc, rewriter.getI64IntegerAttr(paddingInts[dim])));
5685+
cstStrides.push_back(rewriter.create<Torch::ConstantIntOp>(
5686+
loc, rewriter.getI64IntegerAttr(strideInts[dim])));
5687+
}
5688+
5689+
// set dilations separately as for AvgPool op it won't be set
5690+
for (auto dim : llvm::seq<int>(0, dilationInts.size())) {
5691+
cstDilations.push_back(rewriter.create<Torch::ConstantIntOp>(
5692+
loc, rewriter.getI64IntegerAttr(dilationInts[dim])));
5693+
}
5694+
5695+
auto targetListType =
5696+
Torch::ListType::get(Torch::IntType::get(op->getContext()));
5697+
kernelSizeList = rewriter.create<Torch::PrimListConstructOp>(
5698+
loc, targetListType, cstKernel);
5699+
paddingList = rewriter.create<Torch::PrimListConstructOp>(loc, targetListType,
5700+
cstPadding);
5701+
stridesList = rewriter.create<Torch::PrimListConstructOp>(loc, targetListType,
5702+
cstStrides);
5703+
dilationsList = rewriter.create<Torch::PrimListConstructOp>(
5704+
loc, targetListType, cstDilations);
5705+
5706+
return success();
5707+
}
5708+
5709+
template <typename AvgPoolOpT>
5710+
struct CanonicalizeAvgPoolWithSingleIntTuple
5711+
: public mlir::OpRewritePattern<AvgPoolOpT> {
5712+
CanonicalizeAvgPoolWithSingleIntTuple(mlir::MLIRContext *context)
5713+
: OpRewritePattern<AvgPoolOpT>(context, /*benefit=*/1) {}
5714+
5715+
LogicalResult
5716+
matchAndRewrite(AvgPoolOpT op,
5717+
mlir::PatternRewriter &rewriter) const override {
5718+
Value kernel, stride, pad, dilations;
5719+
5720+
auto numSpatialDims = 2;
5721+
if constexpr (std::is_same<AvgPoolOpT, AtenAvgPool3dOp>())
5722+
numSpatialDims = 3;
5723+
5724+
// Attempt to expand params if necessary.
5725+
if (failed(expandPoolParams(op, numSpatialDims, rewriter, kernel, stride,
5726+
pad, dilations)))
5727+
return rewriter.notifyMatchFailure(op,
5728+
"Failed to expand params for pooling");
5729+
5730+
rewriter.replaceOpWithNewOp<AvgPoolOpT>(
5731+
op, op.getResult().getType(), op.getSelf(), kernel, stride, pad,
5732+
op.getCeilMode(), op.getCountIncludePad(), op.getDivisorOverride());
5733+
return success();
5734+
}
5735+
};
5736+
5737+
template <typename MaxPoolOpT>
5738+
struct CanonicalizeMaxPoolWithSingleIntTuple
5739+
: public mlir::OpRewritePattern<MaxPoolOpT> {
5740+
CanonicalizeMaxPoolWithSingleIntTuple(mlir::MLIRContext *context)
5741+
: OpRewritePattern<MaxPoolOpT>(context, /*benefit=*/1) {}
5742+
5743+
LogicalResult
5744+
matchAndRewrite(MaxPoolOpT op,
5745+
mlir::PatternRewriter &rewriter) const override {
5746+
Value kernel, stride, pad, dilations;
5747+
5748+
auto numSpatialDims = 2;
5749+
if constexpr (std::is_same<MaxPoolOpT, AtenMaxPool3dOp>())
5750+
numSpatialDims = 3;
5751+
5752+
// Attempt to expand params if necessary.
5753+
if (failed(expandPoolParams(op, numSpatialDims, rewriter, kernel, stride,
5754+
pad, dilations)))
5755+
return rewriter.notifyMatchFailure(op,
5756+
"Failed to expand params for pooling");
5757+
5758+
rewriter.replaceOpWithNewOp<MaxPoolOpT>(op, op.getResult().getType(),
5759+
op.getSelf(), kernel, stride, pad,
5760+
dilations, op.getCeilMode());
5761+
return success();
5762+
}
5763+
};
5764+
} // namespace
5765+
5766+
//===----------------------------------------------------------------------===//
5767+
// AtenAvgPool2dOp
5768+
//===----------------------------------------------------------------------===//
5769+
void AtenAvgPool2dOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
5770+
MLIRContext *context) {
5771+
patterns.add<CanonicalizeAvgPoolWithSingleIntTuple<AtenAvgPool2dOp>>(context);
5772+
}
5773+
5774+
//===----------------------------------------------------------------------===//
5775+
// AtenAvgPool3dOp
5776+
//===----------------------------------------------------------------------===//
5777+
void AtenAvgPool3dOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
5778+
MLIRContext *context) {
5779+
patterns.add<CanonicalizeAvgPoolWithSingleIntTuple<AtenAvgPool3dOp>>(context);
5780+
}
5781+
5782+
//===----------------------------------------------------------------------===//
5783+
// AtenMaxPool2dOp
5784+
//===----------------------------------------------------------------------===//
5785+
void AtenMaxPool2dOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
5786+
MLIRContext *context) {
5787+
patterns.add<CanonicalizeMaxPoolWithSingleIntTuple<AtenMaxPool2dOp>>(context);
5788+
}
5789+
5790+
//===----------------------------------------------------------------------===//
5791+
// AtenMaxPool3dOp
5792+
//===----------------------------------------------------------------------===//
5793+
void AtenMaxPool3dOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
5794+
MLIRContext *context) {
5795+
patterns.add<CanonicalizeMaxPoolWithSingleIntTuple<AtenMaxPool3dOp>>(context);
5796+
}
5797+
56205798
//===----------------------------------------------------------------------===//
56215799
// AtenLinalgCrossOp
56225800
//===----------------------------------------------------------------------===//

projects/pt1/e2e_testing/xfail_sets.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -526,8 +526,6 @@
526526
"Aten_TrilinearModuleVaryingRanksUnorderedExpands_basic",
527527
"Aten_TrilinearModuleSumAllDims_basic",
528528
"Aten_TrilinearModuleSumdims_basic",
529-
"AvgPool2dSingleIntTupleParamsIncludePadModule_basic",
530-
"AvgPool2dSingleIntTupleParamsModule_basic",
531529
"SliceOutOfLowerBoundEndIndexModule_basic",
532530
"RollModule_basic",
533531
"AdaptiveAvgPool2dDynamicNoBatch_basic",
@@ -982,8 +980,6 @@
982980
}
983981

984982
FX_IMPORTER_STABLEHLO_CRASHING_SET = {
985-
"AvgPool2dSingleIntTupleParamsIncludePadModule_basic",
986-
"AvgPool2dSingleIntTupleParamsModule_basic",
987983
"BatchNorm1DModule_basic",
988984
"BatchNorm2DModule_basic",
989985
"BatchNorm3DModule_basic",
@@ -2852,6 +2848,7 @@
28522848
"AvgPool1dPadCeilPadNotIncluded_basic",
28532849
"AvgPool2dDiffKernelsStridesPadCeilPadNotIncluded_basic",
28542850
"AvgPool3dDiffKernelsStridesPadCeilPadNotIncluded_basic",
2851+
"AvgPool3dSingleIntTupleParamsModule_basic",
28552852
"BatchMlpLayerModule_basic",
28562853
"BincountMinlengthModule_basic",
28572854
"BincountModule_basic",
@@ -3043,6 +3040,7 @@
30433040
"MaxPool2dWithIndicesNonDefaultDilationModule_basic",
30443041
"MaxPool2dWithIndicesNonDefaultParamsModule_basic",
30453042
"MaxPool2dWithIndicesNonDefaultStrideModule_basic",
3043+
"MaxPool2dSingleIntTupleParamsModule_basic",
30463044
"MaxPool3dCeilModeTrueModule_basic",
30473045
"MaxPool3dLargeDatadModule_basic",
30483046
"MaxPool3dModuleRandomSimple_basic",
@@ -3054,6 +3052,7 @@
30543052
"MaxPool3dWithIndicesNonDefaultDilationModule_basic",
30553053
"MaxPool3dWithIndicesNonDefaultParamsModule_basic",
30563054
"MaxPool3dWithIndicesNonDefaultStrideModule_basic",
3055+
"MaxPool3dSingleIntTupleParamsModule_basic",
30573056
"MaxUnpool3dModule_basic",
30583057
"MaxUnpool3dModulePad0_basic",
30593058
"MaxUnpool2dModule_basic",
@@ -3554,6 +3553,7 @@
35543553
"AvgPool3dStaticModule_basic",
35553554
"AvgPool3dCountIncludePadFalse_basic",
35563555
"AvgPool3dCountIncludePadFalseWithoutPadding_basic",
3556+
"AvgPool3dSingleIntTupleParamsModule_basic",
35573557
"Conv_Transpose1dModule_basic",
35583558
"Conv_Transpose1dStaticModule_basic",
35593559
"Conv_Transpose2dStaticModule_basic",
@@ -3809,6 +3809,7 @@
38093809
"MaxPool3dWithIndicesNonDefaultParamsModule_basic",
38103810
"MaxPool3dWithIndicesNonDefaultStrideModule_basic",
38113811
"MaxPool3dWithIndicesStaticModule_basic",
3812+
"MaxPool3dSingleIntTupleParamsModule_basic",
38123813
"MeanDimEmptyDimModule_basic",
38133814
"MlGroupNormManualModule_basic",
38143815
"MlGroupNormModule_basic",
@@ -4252,6 +4253,7 @@
42524253
"AvgPool2dIntModule_basic",
42534254
"AvgPool2dStaticModule_basic",
42544255
"AvgPool2dWithoutPadModule_basic",
4256+
"AvgPool3dSingleIntTupleParamsModule_basic",
42554257
"BatchMlpLayerModule_basic",
42564258
"BernoulliFloatModule_basic",
42574259
"BernoulliModule_basic",
@@ -4659,6 +4661,7 @@
46594661
"MaxPool3dWithIndicesNonDefaultParamsModule_basic",
46604662
"MaxPool3dWithIndicesNonDefaultStrideModule_basic",
46614663
"MaxPool3dWithIndicesStaticModule_basic",
4664+
"MaxPool3dSingleIntTupleParamsModule_basic",
46624665
"MeanDimAllReduceKeepdimModule_basic",
46634666
"MeanDimAllReduceModule_basic",
46644667
"MeanDimDtypeModule_basic",

0 commit comments

Comments
 (0)