Commit 2f9a68c

Add canonicalization pattern for maxpool3d with indices op (#3704)
As discussed in #3652, we replace aten.max_pool3d_with_indices with aten.max_pool3d when the indices result has no users.
1 parent 55ff110 commit 2f9a68c
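
For illustration, the rewrite at the IR level looks roughly like this (a sketch derived from the new test case below; value names and the list-construct operands are abbreviated):

  // Before: the second result (the indices) has no users.
  %values, %indices = torch.aten.max_pool3d_with_indices %self, %kernel, %stride, %pad, %dil, %false
      : ... -> !torch.vtensor<[10,64,56,56,56],f32>, !torch.vtensor<[10,64,56,56,56],si64>

  // After canonicalization: only the pooling result is computed.
  %values = torch.aten.max_pool3d %self, %kernel, %stride, %pad, %dil, %false
      : ... -> !torch.vtensor<[10,64,56,56,56],f32>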

4 files changed: +58 −8 lines


include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td

Lines changed: 1 addition & 0 deletions
@@ -7352,6 +7352,7 @@ def Torch_AtenMaxPool3dWithIndicesOp : Torch_Op<"aten.max_pool3d_with_indices",
       printDefaultTorchOp(printer, *this, 6, 2);
     }
   }];
+  let hasCanonicalizer = 1;
 }

 def Torch_AtenMaxPool3dWithIndicesBackwardOp : Torch_Op<"aten.max_pool3d_with_indices_backward", [

lib/Dialect/Torch/IR/TorchOps.cpp

Lines changed: 37 additions & 7 deletions
@@ -5188,26 +5188,56 @@ OpFoldResult PrimsConvertElementTypeOp::fold(FoldAdaptor adaptor) {
 }

 //===----------------------------------------------------------------------===//
-// AtenMaxPool2dWithIndicesOp
+// AtenMaxPoolWithIndicesOp
 //===----------------------------------------------------------------------===//

-void AtenMaxPool2dWithIndicesOp::getCanonicalizationPatterns(
-    RewritePatternSet &patterns, MLIRContext *context) {
-  patterns.add(+[](AtenMaxPool2dWithIndicesOp op, PatternRewriter &rewriter) {
+namespace {
+
+template <typename OpTy> struct MaxPoolWithoutIndices {
+  using type = OpTy;
+};
+
+template <> struct MaxPoolWithoutIndices<AtenMaxPool2dWithIndicesOp> {
+  using type = AtenMaxPool2dOp;
+};
+
+template <> struct MaxPoolWithoutIndices<AtenMaxPool3dWithIndicesOp> {
+  using type = AtenMaxPool3dOp;
+};
+
+} // namespace
+
+template <typename OpTy>
+struct SimplifyMaxPoolWithIndices : public mlir::OpRewritePattern<OpTy> {
+  SimplifyMaxPoolWithIndices(mlir::MLIRContext *context)
+      : OpRewritePattern<OpTy>(context, /*benefit=*/1) {}
+
+  LogicalResult
+  matchAndRewrite(OpTy op, mlir::PatternRewriter &rewriter) const override {
     if (!op.getResult1().use_empty()) {
       return rewriter.notifyMatchFailure(
-          op, "result1 of MaxPool2dWithIndices should be unused");
+          op, "result1 of MaxPoolWithIndices should be unused");
     }

-    Value result = rewriter.create<Torch::AtenMaxPool2dOp>(
+    Value result = rewriter.create<typename MaxPoolWithoutIndices<OpTy>::type>(
         op->getLoc(), op.getResult0().getType(), op.getSelf(),
         op.getKernelSize(), op.getStride(), op.getPadding(), op.getDilation(),
         op.getCeilMode());

     op.getResult0().replaceAllUsesWith(result);
     rewriter.eraseOp(op);
     return success();
-  });
+  }
+};
+
+void AtenMaxPool2dWithIndicesOp::getCanonicalizationPatterns(
+    RewritePatternSet &patterns, MLIRContext *context) {
+  patterns.add<SimplifyMaxPoolWithIndices<AtenMaxPool2dWithIndicesOp>>(context);
+}
+
+void AtenMaxPool3dWithIndicesOp::getCanonicalizationPatterns(
+    RewritePatternSet &patterns, MLIRContext *context) {
+  patterns.add<SimplifyMaxPoolWithIndices<AtenMaxPool3dWithIndicesOp>>(context);
 }

 //===----------------------------------------------------------------------===//
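
Design note: the previous 2-D-only lambda is generalized into the templated SimplifyMaxPoolWithIndices pattern, with the MaxPoolWithoutIndices trait mapping each with-indices op to its plain counterpart (AtenMaxPool2dWithIndicesOp to AtenMaxPool2dOp, AtenMaxPool3dWithIndicesOp to AtenMaxPool3dOp). Since both ops now declare hasCanonicalizer, the standard MLIR canonicalize pass collects these patterns automatically through getCanonicalizationPatterns; no separate pass registration is needed.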

projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py

Lines changed: 2 additions & 1 deletion
@@ -636,7 +636,8 @@ def emit_with_mutating_variants(key, **kwargs):
     emit("aten::max_pool3d : (Tensor, int[], int[], int[], int[], bool) -> (Tensor)")
     emit("aten::max_unpool3d : (Tensor, Tensor, int[], int[], int[]) -> (Tensor)")
     emit(
-        "aten::max_pool3d_with_indices : (Tensor, int[], int[], int[], int[], bool) -> (Tensor, Tensor)"
+        "aten::max_pool3d_with_indices : (Tensor, int[], int[], int[], int[], bool) -> (Tensor, Tensor)",
+        has_canonicalizer=True,
     )
     emit(
         "aten::max_pool3d_with_indices_backward : (Tensor, Tensor, int[], int[], int[], int[], bool, Tensor) -> (Tensor)"

test/Dialect/Torch/canonicalize.mlir

Lines changed: 18 additions & 0 deletions
@@ -3136,6 +3136,24 @@ func.func @torch.aten.max_pool2d_with_indices$canonicalize(%arg0: !torch.vtensor

 // -----

+// CHECK-LABEL: @torch.aten.max_pool3d_with_indices$canonicalize(
+// CHECK: %[[ARG:.*]]: !torch.vtensor<[10,64,112,112,112],f32>) -> !torch.vtensor<[10,64,56,56,56],f32> {
+// CHECK: %[[RET:.*]] = torch.aten.max_pool3d %[[ARG]]
+// CHECK: return %[[RET]] : !torch.vtensor<[10,64,56,56,56],f32>
+func.func @torch.aten.max_pool3d_with_indices$canonicalize(%arg0: !torch.vtensor<[10,64,112,112,112],f32>) -> !torch.vtensor<[10,64,56,56,56],f32> {
+  %false = torch.constant.bool false
+  %int1 = torch.constant.int 1
+  %int2 = torch.constant.int 2
+  %int3 = torch.constant.int 3
+  %29 = torch.prim.ListConstruct %int3, %int3 : (!torch.int, !torch.int) -> !torch.list<int>
+  %30 = torch.prim.ListConstruct %int2, %int2 : (!torch.int, !torch.int) -> !torch.list<int>
+  %31 = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
+  %result0, %result1 = torch.aten.max_pool3d_with_indices %arg0, %29, %30, %31, %31, %false : !torch.vtensor<[10,64,112,112,112],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool -> !torch.vtensor<[10,64,56,56,56],f32>, !torch.vtensor<[10,64,56,56,56],si64>
+  return %result0 : !torch.vtensor<[10,64,56,56,56],f32>
+}
+
+// -----
+
 // CHECK-LABEL: @torch.aten.clone$no_fold(
 func.func @torch.aten.clone$no_fold(%arg0: !torch.vtensor<[1,2,50,4],f32>) -> (!torch.tensor) {
 // CHECK: %{{.*}} = torch.aten.clone %{{.*}}, %{{.*}} : !torch.vtensor<[1,2,50,4],f32>, !torch.none -> !torch.vtensor
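
Usage note: the CHECK lines assert that, after the canonicalizer runs, the function computes its result with a plain torch.aten.max_pool3d and the with-indices op, along with its unused si64 indices tensor, is gone; the trailing clone test is just the unchanged surrounding context of the diff.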
