Skip to content

Commit 903533a

Browse files
committed
[Torch] Canonicalize pool ops with single int tuple params.
1 parent acf7fdd commit 903533a

File tree

5 files changed

+266
-10
lines changed

5 files changed

+266
-10
lines changed

include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7744,6 +7744,7 @@ def Torch_AtenMaxPool2dOp : Torch_Op<"aten.max_pool2d", [
77447744
printDefaultTorchOp(printer, *this, 6, 1);
77457745
}
77467746
}];
7747+
let hasCanonicalizer = 1;
77477748
}
77487749

77497750
def Torch_AtenMaxUnpool2dOp : Torch_Op<"aten.max_unpool2d", [
@@ -7857,6 +7858,7 @@ def Torch_AtenMaxPool3dOp : Torch_Op<"aten.max_pool3d", [
78577858
printDefaultTorchOp(printer, *this, 6, 1);
78587859
}
78597860
}];
7861+
let hasCanonicalizer = 1;
78607862
}
78617863

78627864
def Torch_AtenMaxUnpool3dOp : Torch_Op<"aten.max_unpool3d", [
@@ -8001,6 +8003,7 @@ def Torch_AtenAvgPool2dOp : Torch_Op<"aten.avg_pool2d", [
80018003
printDefaultTorchOp(printer, *this, 7, 1);
80028004
}
80038005
}];
8006+
let hasCanonicalizer = 1;
80048007
}
80058008

80068009
def Torch_AtenAvgPool2dBackwardOp : Torch_Op<"aten.avg_pool2d_backward", [
@@ -8060,6 +8063,7 @@ def Torch_AtenAvgPool3dOp : Torch_Op<"aten.avg_pool3d", [
80608063
printDefaultTorchOp(printer, *this, 7, 1);
80618064
}
80628065
}];
8066+
let hasCanonicalizer = 1;
80638067
}
80648068

80658069
def Torch_AtenAvgPool3dBackwardOp : Torch_Op<"aten.avg_pool3d_backward", [

lib/Conversion/TorchToTosa/TorchToTosa.cpp

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -6046,12 +6046,6 @@ void expandPoolParams(AtenOpT op, SmallVectorImpl<int64_t> &params,
60466046
if constexpr (std::is_same<AtenOpT, AtenMaxPool1dOp>() ||
60476047
std::is_same<AtenOpT, AtenAvgPool1dOp>())
60486048
params.push_back(val);
6049-
6050-
if constexpr (std::is_same<AtenOpT, AtenMaxPool2dOp>() ||
6051-
std::is_same<AtenOpT, AtenAvgPool2dOp>()) {
6052-
if (params.size() == 1)
6053-
params.push_back(params[0]);
6054-
}
60556049
}
60566050

60576051
// Checks the validity of pooling parameters and stores them in the respective

lib/Dialect/Torch/IR/TorchOps.cpp

Lines changed: 178 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5601,6 +5601,184 @@ void Aten_AdaptiveAvgPool2dOp::getCanonicalizationPatterns(
56015601
});
56025602
}
56035603

5604+
namespace {
5605+
5606+
void expand(SmallVectorImpl<int64_t> &params, int numSpatialDims) {
5607+
if (params.size() == 1) {
5608+
for (auto _ : llvm::seq<int>(0, numSpatialDims - 1)) {
5609+
params.push_back(params[0]);
5610+
}
5611+
}
5612+
}
5613+
5614+
template <typename AtenPoolOpT>
5615+
LogicalResult expandPoolParams(AtenPoolOpT op, int numSpatialDims,
5616+
mlir::PatternRewriter &rewriter,
5617+
Value &kernelSizeList, Value &stridesList,
5618+
Value &paddingList, Value &dilationsList) {
5619+
5620+
SmallVector<int64_t, 3> kernelSizeInts, strideInts, paddingInts, dilationInts;
5621+
if (!matchPattern(op.getKernelSize(),
5622+
m_TorchListOfConstantInts(kernelSizeInts)))
5623+
return rewriter.notifyMatchFailure(
5624+
op, "Non-const kernel_size for pooling op unsupported");
5625+
5626+
if (!matchPattern(op.getPadding(), m_TorchListOfConstantInts(paddingInts)))
5627+
return rewriter.notifyMatchFailure(
5628+
op, "Non-const padding factor for pooling op unsupported");
5629+
5630+
if (!matchPattern(op.getStride(), m_TorchListOfConstantInts(strideInts)))
5631+
return rewriter.notifyMatchFailure(
5632+
op, "Non-const stride for pooling op unsupported");
5633+
5634+
if constexpr (std::is_same<AtenPoolOpT, AtenMaxPool2dOp>() ||
5635+
std::is_same<AtenPoolOpT, AtenMaxPool3dOp>()) {
5636+
if (!matchPattern(op.getDilation(),
5637+
m_TorchListOfConstantInts(dilationInts)))
5638+
return rewriter.notifyMatchFailure(
5639+
op, "Non-const dilation for pooling op unsupported");
5640+
5641+
if (kernelSizeInts.size() != 1 && paddingInts.size() != 1 &&
5642+
strideInts.size() != 1 && dilationInts.size() != 1) {
5643+
return rewriter.notifyMatchFailure(
5644+
op,
5645+
"Expected one of kernel/stride/padding/dilation to be singleton.");
5646+
}
5647+
5648+
expand(dilationInts, numSpatialDims);
5649+
5650+
} else if (kernelSizeInts.size() != 1 && paddingInts.size() != 1 &&
5651+
strideInts.size() != 1) {
5652+
return rewriter.notifyMatchFailure(
5653+
op, "Expected one of kernel/stride/padding to be singleton.");
5654+
}
5655+
5656+
// expand singleton elements
5657+
expand(kernelSizeInts, numSpatialDims);
5658+
expand(paddingInts, numSpatialDims);
5659+
expand(strideInts, numSpatialDims);
5660+
5661+
Location loc = op.getLoc();
5662+
5663+
SmallVector<Value> cstKernel, cstPadding, cstStrides, cstDilations;
5664+
for (auto dim : llvm::seq<int>(0, kernelSizeInts.size())) {
5665+
cstKernel.push_back(rewriter.create<Torch::ConstantIntOp>(
5666+
loc, rewriter.getI64IntegerAttr(kernelSizeInts[dim])));
5667+
cstPadding.push_back(rewriter.create<Torch::ConstantIntOp>(
5668+
loc, rewriter.getI64IntegerAttr(paddingInts[dim])));
5669+
cstStrides.push_back(rewriter.create<Torch::ConstantIntOp>(
5670+
loc, rewriter.getI64IntegerAttr(strideInts[dim])));
5671+
}
5672+
5673+
// set dilations separately as for AvgPool op it won't be set
5674+
for (auto dim : llvm::seq<int>(0, dilationInts.size())) {
5675+
cstDilations.push_back(rewriter.create<Torch::ConstantIntOp>(
5676+
loc, rewriter.getI64IntegerAttr(dilationInts[dim])));
5677+
}
5678+
5679+
auto targetListType =
5680+
Torch::ListType::get(Torch::IntType::get(op->getContext()));
5681+
kernelSizeList = rewriter.create<Torch::PrimListConstructOp>(
5682+
loc, targetListType, cstKernel);
5683+
paddingList = rewriter.create<Torch::PrimListConstructOp>(loc, targetListType,
5684+
cstPadding);
5685+
stridesList = rewriter.create<Torch::PrimListConstructOp>(loc, targetListType,
5686+
cstStrides);
5687+
dilationsList = rewriter.create<Torch::PrimListConstructOp>(
5688+
loc, targetListType, cstDilations);
5689+
5690+
return success();
5691+
}
5692+
5693+
template <typename AvgPoolOpT>
5694+
struct CanonicalizeAvgPoolWithSingleIntTuple
5695+
: public mlir::OpRewritePattern<AvgPoolOpT> {
5696+
CanonicalizeAvgPoolWithSingleIntTuple(mlir::MLIRContext *context)
5697+
: OpRewritePattern<AvgPoolOpT>(context, /*benefit=*/1) {}
5698+
5699+
LogicalResult
5700+
matchAndRewrite(AvgPoolOpT op,
5701+
mlir::PatternRewriter &rewriter) const override {
5702+
Value kernel, stride, pad, dilations;
5703+
5704+
auto numSpatialDims = 2;
5705+
if constexpr (std::is_same<AvgPoolOpT, AtenAvgPool3dOp>())
5706+
numSpatialDims = 3;
5707+
5708+
// Attempt to expand params if necessary.
5709+
if (failed(expandPoolParams(op, numSpatialDims, rewriter, kernel, stride,
5710+
pad, dilations)))
5711+
return rewriter.notifyMatchFailure(op,
5712+
"Failed to expand params for pooling");
5713+
5714+
rewriter.replaceOpWithNewOp<AvgPoolOpT>(
5715+
op, op.getResult().getType(), op.getSelf(), kernel, stride, pad,
5716+
op.getCeilMode(), op.getCountIncludePad(), op.getDivisorOverride());
5717+
return success();
5718+
}
5719+
};
5720+
5721+
template <typename MaxPoolOpT>
5722+
struct CanonicalizeMaxPoolWithSingleIntTuple
5723+
: public mlir::OpRewritePattern<MaxPoolOpT> {
5724+
CanonicalizeMaxPoolWithSingleIntTuple(mlir::MLIRContext *context)
5725+
: OpRewritePattern<MaxPoolOpT>(context, /*benefit=*/1) {}
5726+
5727+
LogicalResult
5728+
matchAndRewrite(MaxPoolOpT op,
5729+
mlir::PatternRewriter &rewriter) const override {
5730+
Value kernel, stride, pad, dilations;
5731+
5732+
auto numSpatialDims = 2;
5733+
if constexpr (std::is_same<MaxPoolOpT, AtenMaxPool3dOp>())
5734+
numSpatialDims = 3;
5735+
5736+
// Attempt to expand params if necessary.
5737+
if (failed(expandPoolParams(op, numSpatialDims, rewriter, kernel, stride,
5738+
pad, dilations)))
5739+
return rewriter.notifyMatchFailure(op,
5740+
"Failed to expand params for pooling");
5741+
5742+
rewriter.replaceOpWithNewOp<MaxPoolOpT>(op, op.getResult().getType(),
5743+
op.getSelf(), kernel, stride, pad,
5744+
dilations, op.getCeilMode());
5745+
return success();
5746+
}
5747+
};
5748+
} // namespace
5749+
5750+
//===----------------------------------------------------------------------===//
5751+
// AtenAvgPool2dOp
5752+
//===----------------------------------------------------------------------===//
5753+
void AtenAvgPool2dOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
5754+
MLIRContext *context) {
5755+
patterns.add<CanonicalizeAvgPoolWithSingleIntTuple<AtenAvgPool2dOp>>(context);
5756+
}
5757+
5758+
//===----------------------------------------------------------------------===//
5759+
// AtenAvgPool3dOp
5760+
//===----------------------------------------------------------------------===//
5761+
void AtenAvgPool3dOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
5762+
MLIRContext *context) {
5763+
patterns.add<CanonicalizeAvgPoolWithSingleIntTuple<AtenAvgPool3dOp>>(context);
5764+
}
5765+
5766+
//===----------------------------------------------------------------------===//
5767+
// AtenMaxPool2dOp
5768+
//===----------------------------------------------------------------------===//
5769+
void AtenMaxPool2dOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
5770+
MLIRContext *context) {
5771+
patterns.add<CanonicalizeMaxPoolWithSingleIntTuple<AtenMaxPool2dOp>>(context);
5772+
}
5773+
5774+
//===----------------------------------------------------------------------===//
5775+
// AtenMaxPool3dOp
5776+
//===----------------------------------------------------------------------===//
5777+
void AtenMaxPool3dOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
5778+
MLIRContext *context) {
5779+
patterns.add<CanonicalizeMaxPoolWithSingleIntTuple<AtenMaxPool3dOp>>(context);
5780+
}
5781+
56045782
//===----------------------------------------------------------------------===//
56055783
// AtenLinalgCrossOp
56065784
//===----------------------------------------------------------------------===//

projects/pt1/e2e_testing/xfail_sets.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -531,8 +531,6 @@
531531
"Aten_TrilinearModuleVaryingRanksUnorderedExpands_basic",
532532
"Aten_TrilinearModuleSumAllDims_basic",
533533
"Aten_TrilinearModuleSumdims_basic",
534-
"AvgPool2dSingleIntTupleParamsIncludePadModule_basic",
535-
"AvgPool2dSingleIntTupleParamsModule_basic",
536534
"SliceOutOfLowerBoundEndIndexModule_basic",
537535
"RollModule_basic",
538536
}
@@ -977,8 +975,6 @@
977975
}
978976

979977
FX_IMPORTER_STABLEHLO_CRASHING_SET = {
980-
"AvgPool2dSingleIntTupleParamsIncludePadModule_basic",
981-
"AvgPool2dSingleIntTupleParamsModule_basic",
982978
"BatchNorm1DModule_basic",
983979
"BatchNorm2DModule_basic",
984980
"BatchNorm3DModule_basic",
@@ -2833,6 +2829,7 @@
28332829
"AvgPool1dPadCeilPadNotIncluded_basic",
28342830
"AvgPool2dDiffKernelsStridesPadCeilPadNotIncluded_basic",
28352831
"AvgPool3dDiffKernelsStridesPadCeilPadNotIncluded_basic",
2832+
"AvgPool3dSingleIntTupleParamsModule_basic",
28362833
"BatchMlpLayerModule_basic",
28372834
"BincountMinlengthModule_basic",
28382835
"BincountModule_basic",
@@ -3020,6 +3017,7 @@
30203017
"MaxPool2dWithIndicesNonDefaultDilationModule_basic",
30213018
"MaxPool2dWithIndicesNonDefaultParamsModule_basic",
30223019
"MaxPool2dWithIndicesNonDefaultStrideModule_basic",
3020+
"MaxPool2dSingleIntTupleParamsModule_basic",
30233021
"MaxPool3dCeilModeTrueModule_basic",
30243022
"MaxPool3dLargeDatadModule_basic",
30253023
"MaxPool3dModuleRandomSimple_basic",
@@ -3031,6 +3029,7 @@
30313029
"MaxPool3dWithIndicesNonDefaultDilationModule_basic",
30323030
"MaxPool3dWithIndicesNonDefaultParamsModule_basic",
30333031
"MaxPool3dWithIndicesNonDefaultStrideModule_basic",
3032+
"MaxPool3dSingleIntTupleParamsModule_basic",
30343033
"MaxUnpool3dModule_basic",
30353034
"MaxUnpool3dModulePad0_basic",
30363035
"MeanDimEmptyDimModule_basic",
@@ -3516,6 +3515,7 @@
35163515
"AvgPool3dStaticModule_basic",
35173516
"AvgPool3dCountIncludePadFalse_basic",
35183517
"AvgPool3dCountIncludePadFalseWithoutPadding_basic",
3518+
"AvgPool3dSingleIntTupleParamsModule_basic",
35193519
"Conv_Transpose1dModule_basic",
35203520
"Conv_Transpose1dStaticModule_basic",
35213521
"Conv_Transpose2dStaticModule_basic",
@@ -3769,6 +3769,7 @@
37693769
"MaxPool3dWithIndicesNonDefaultParamsModule_basic",
37703770
"MaxPool3dWithIndicesNonDefaultStrideModule_basic",
37713771
"MaxPool3dWithIndicesStaticModule_basic",
3772+
"MaxPool3dSingleIntTupleParamsModule_basic",
37723773
"MeanDimEmptyDimModule_basic",
37733774
"MlGroupNormManualModule_basic",
37743775
"MlGroupNormModule_basic",
@@ -4186,6 +4187,7 @@
41864187
"AvgPool2dIntModule_basic",
41874188
"AvgPool2dStaticModule_basic",
41884189
"AvgPool2dWithoutPadModule_basic",
4190+
"AvgPool3dSingleIntTupleParamsModule_basic",
41894191
"BatchMlpLayerModule_basic",
41904192
"BernoulliFloatModule_basic",
41914193
"BernoulliModule_basic",
@@ -4593,6 +4595,7 @@
45934595
"MaxPool3dWithIndicesNonDefaultParamsModule_basic",
45944596
"MaxPool3dWithIndicesNonDefaultStrideModule_basic",
45954597
"MaxPool3dWithIndicesStaticModule_basic",
4598+
"MaxPool3dSingleIntTupleParamsModule_basic",
45964599
"MeanDimAllReduceKeepdimModule_basic",
45974600
"MeanDimAllReduceModule_basic",
45984601
"MeanDimDtypeModule_basic",

0 commit comments

Comments
 (0)