Skip to content

Commit 67be5fe

Browse files
committed
Add canonicalizer in torch_ods_gen.py + update error message
1 parent 94ad9d2 commit 67be5fe

File tree

2 files changed

+16
-8
lines changed

2 files changed

+16
-8
lines changed

lib/Dialect/Torch/IR/TorchOps.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5708,8 +5708,8 @@ struct CanonicalizeAvgPoolWithSingleIntTuple
57085708
// Attempt to expand params if necessary.
57095709
if (failed(expandPoolParams(op, numSpatialDims, rewriter, kernel, stride,
57105710
pad, dilations)))
5711-
return rewriter.notifyMatchFailure(op,
5712-
"Failed to expand params for pooling");
5711+
return rewriter.notifyMatchFailure(
5712+
op, "Failed to expand params for AvgPooling");
57135713

57145714
rewriter.replaceOpWithNewOp<AvgPoolOpT>(
57155715
op, op.getResult().getType(), op.getSelf(), kernel, stride, pad,
@@ -5736,8 +5736,8 @@ struct CanonicalizeMaxPoolWithSingleIntTuple
57365736
// Attempt to expand params if necessary.
57375737
if (failed(expandPoolParams(op, numSpatialDims, rewriter, kernel, stride,
57385738
pad, dilations)))
5739-
return rewriter.notifyMatchFailure(op,
5740-
"Failed to expand params for pooling");
5739+
return rewriter.notifyMatchFailure(
5740+
op, "Failed to expand params for MaxPooling");
57415741

57425742
rewriter.replaceOpWithNewOp<MaxPoolOpT>(op, op.getResult().getType(),
57435743
op.getSelf(), kernel, stride, pad,

projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -657,7 +657,10 @@ def emit_with_mutating_variants(key, **kwargs):
657657
emit(
658658
"aten::max_pool1d_with_indices : (Tensor, int[], int[], int[], int[], bool) -> (Tensor, Tensor)"
659659
)
660-
emit("aten::max_pool2d : (Tensor, int[], int[], int[], int[], bool) -> (Tensor)")
660+
emit(
661+
"aten::max_pool2d : (Tensor, int[], int[], int[], int[], bool) -> (Tensor)",
662+
has_canonicalizer=True,
663+
)
661664
emit("aten::max_unpool2d : (Tensor, Tensor, int[]) -> (Tensor)")
662665
emit(
663666
"aten::max_pool2d_with_indices : (Tensor, int[], int[], int[], int[], bool) -> (Tensor, Tensor)",
@@ -666,7 +669,10 @@ def emit_with_mutating_variants(key, **kwargs):
666669
emit(
667670
"aten::max_pool2d_with_indices_backward : (Tensor, Tensor, int[], int[], int[], int[], bool, Tensor) -> (Tensor)"
668671
)
669-
emit("aten::max_pool3d : (Tensor, int[], int[], int[], int[], bool) -> (Tensor)")
672+
emit(
673+
"aten::max_pool3d : (Tensor, int[], int[], int[], int[], bool) -> (Tensor)",
674+
has_canonicalizer=True,
675+
)
670676
emit("aten::max_unpool3d : (Tensor, Tensor, int[], int[], int[]) -> (Tensor)")
671677
emit(
672678
"aten::max_pool3d_with_indices : (Tensor, int[], int[], int[], int[], bool) -> (Tensor, Tensor)",
@@ -677,13 +683,15 @@ def emit_with_mutating_variants(key, **kwargs):
677683
)
678684
emit("aten::avg_pool1d : (Tensor, int[], int[], int[], bool, bool) -> (Tensor)")
679685
emit(
680-
"aten::avg_pool2d : (Tensor, int[], int[], int[], bool, bool, int?) -> (Tensor)"
686+
"aten::avg_pool2d : (Tensor, int[], int[], int[], bool, bool, int?) -> (Tensor)",
687+
has_canonicalizer=True,
681688
)
682689
emit(
683690
"aten::avg_pool2d_backward : (Tensor, Tensor, int[], int[], int[], bool, bool, int?) -> (Tensor)"
684691
)
685692
emit(
686-
"aten::avg_pool3d : (Tensor, int[], int[], int[], bool, bool, int?) -> (Tensor)"
693+
"aten::avg_pool3d : (Tensor, int[], int[], int[], bool, bool, int?) -> (Tensor)",
694+
has_canonicalizer=True,
687695
)
688696
emit(
689697
"aten::avg_pool3d_backward : (Tensor, Tensor, int[], int[], int[], bool, bool, int?) -> (Tensor)"

0 commit comments

Comments
 (0)