
Commit 8bc028a

ljfitz authored and cathyzhyi committed
Fold __is__ and unchecked_cast of derefine
The added e2e maxpool test case from #545 was not getting a static shape due to an unfolded prim.If when RefineTypes was called. This was because of unfolded torch.aten.__is__ and torch.prim.unchecked_cast operators with torch.derefine operands.
1 parent e1b3e5b commit 8bc028a
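
For context, the IR pattern that blocked RefineTypes looks roughly like the sketch below. The op and type syntax follows the test cases added in this commit; the function itself and the tensor plumbing around the torch.prim.If are illustrative assumptions, not code from the repository.

func @illustrative(%list: !torch.list<!torch.int>, %t: !torch.tensor) -> !torch.tensor {
  %none = torch.constant.none
  // torch.derefine hides the static !torch.list<!torch.int> type behind an
  // optional, so nothing could prove the comparison below is always false.
  %opt = torch.derefine %list : !torch.list<!torch.int> to !torch.optional<!torch.list<!torch.int>>
  %pred = torch.aten.__is__ %opt, %none : !torch.optional<!torch.list<!torch.int>>, !torch.none -> !torch.bool
  // With the predicate opaque, this prim.If survives canonicalization and
  // stops RefineTypes from propagating a static shape across it.
  %ret = torch.prim.If %pred -> (!torch.tensor) {
    torch.prim.If.yield %t : !torch.tensor
  } else {
    torch.prim.If.yield %t : !torch.tensor
  }
  return %ret : !torch.tensor
}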

File tree

5 files changed (+60 additions, −6 deletions)


e2e_testing/torchscript/basic.py

Lines changed: 22 additions & 0 deletions
@@ -268,6 +268,28 @@ def forward(self, x):
 def MaxPool2dModule_basic(module, tu: TestUtils):
     module.forward(tu.rand(1, 1, 20, 20) - 0.5)
 
+class MaxPool2dStaticModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.mp2d = torch.nn.MaxPool2d(kernel_size=[3, 3],
+                                       stride=[2, 2],
+                                       padding=[1, 1],
+                                       dilation=[1, 1])
+
+    @export
+    @annotate_args([
+        None,
+        ([1, 64, 112, 112], torch.float32, True),
+    ])
+    def forward(self, x):
+        return self.mp2d(x)
+
+
+@register_test_case(module_factory=lambda: MaxPool2dStaticModule())
+def MaxPool2dStaticModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(1, 64, 112, 112))
+
+# ==============================================================================
 
 class ConstantPad2dStaticModule(torch.nn.Module):
     def __init__(self):

include/torch-mlir/Dialect/Torch/IR/GeneratedPrimOps.td

Lines changed: 1 addition & 0 deletions
@@ -201,6 +201,7 @@ def Torch_PrimUncheckedCastOp : Torch_Op<"prim.unchecked_cast", [
     AnyTorchType:$result
   );
   let assemblyFormat = "$x attr-dict `:` qualified(type($x)) `->` qualified(type($result))";
+  let hasFolder = 1;
 }
 
 def Torch_PrimPrintOp : Torch_Op<"prim.Print", [

lib/Dialect/Torch/IR/TorchOps.cpp

Lines changed: 20 additions & 5 deletions
@@ -390,12 +390,19 @@ bool DerefineOp::areCastCompatible(mlir::TypeRange inputs,
 
 template <typename OpTy>
 static OpFoldResult atenIsOrIsNotFoldHelper(OpTy op, bool equalIsTrue) {
-  Type lhsType = op.self().getType();
-  Type rhsType = op.obj().getType();
+  Value lhs = op.self();
+  Value rhs = op.obj();
 
-  // If either type is a NoneType, make it be the lhsType.
-  if (rhsType.template isa<Torch::NoneType>())
-    std::swap(lhsType, rhsType);
+  // If either value is typed NoneType, make it be the lhs.
+  if (rhs.getType().template isa<Torch::NoneType>())
+    std::swap(lhs, rhs);
+
+  if (rhs.getType().template isa<Torch::OptionalType>())
+    if (auto derefine = rhs.getDefiningOp<Torch::DerefineOp>())
+      rhs = derefine.operand();
+
+  Type lhsType = lhs.getType();
+  Type rhsType = rhs.getType();
   // TODO: Implement and use subtype infra for this.
   // If neither type is a subtype of the other, then the result is false.
   if (lhsType.template isa<Torch::NoneType>() &&
@@ -986,6 +993,14 @@ bool PrimUncheckedCastOp::areCastCompatible(mlir::TypeRange inputs,
   return isValidSubtype(outputs[0], inputs[0]);
 }
 
+OpFoldResult PrimUncheckedCastOp::fold(ArrayRef<Attribute> operands) {
+  if (auto derefineOp = x().getDefiningOp<Torch::DerefineOp>()) {
+    if (derefineOp.operand().getType() == getType())
+      return derefineOp.operand();
+  }
+  return nullptr;
+}
+
 //===----------------------------------------------------------------------===//
 // Aten__Getitem__TOp
 //===----------------------------------------------------------------------===//

python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py

Lines changed: 1 addition & 1 deletion
@@ -416,7 +416,7 @@ def emit(key, **kwargs):
     emit("prim::max.int : (int, int) -> (int)")
     emit("prim::RaiseException : (str, str?) -> ()")
     emit("prim::Uninitialized : () -> (Any)", traits=["NoSideEffect"])
-    emit("prim::unchecked_cast : (t) -> (t)",
+    emit("prim::unchecked_cast : (t) -> (t)", has_folder=True,
          traits=["DeclareOpInterfaceMethods<CastOpInterface>"])
     emit("prim::Print : (...) -> ()")
     emit("prim::tolist : (...) -> (...)")

test/Dialect/Torch/canonicalize.mlir

Lines changed: 16 additions & 0 deletions
@@ -8,6 +8,15 @@ func @torch.aten.__is__(%arg0: !torch.list<!torch.int>, %arg1: !torch.none) -> !
   return %0 : !torch.bool
 }
 
+// CHECK-LABEL: func @torch.aten.__is__$derefine_is_none
+// CHECK: %[[FALSE:.*]] = torch.constant.bool false
+// CHECK: return %[[FALSE]] : !torch.bool
+func @torch.aten.__is__$derefine_is_none(%arg0: !torch.list<!torch.int>, %arg1: !torch.none) -> !torch.bool {
+  %0 = torch.derefine %arg0 : !torch.list<!torch.int> to !torch.optional<!torch.list<!torch.int>>
+  %1 = torch.aten.__is__ %0, %arg1 : !torch.optional<!torch.list<!torch.int>>, !torch.none -> !torch.bool
+  return %1 : !torch.bool
+}
+
 // CHECK-LABEL: func @torch.aten.__is__$none_is_none
 // CHECK: %[[TRUE:.*]] = torch.constant.bool true
 // CHECK: return %[[TRUE]] : !torch.bool
@@ -644,6 +653,13 @@ func @torch.prim.TupleIndex$out_of_bound(%t0: !torch.tensor, %t1: !torch.tensor,
   return %1 : !torch.tensor
 }
 
+// CHECK-LABEL: func @torch.prim.unchecked_cast$derefine
+// CHECK-NEXT: return %arg0 : !torch.list<!torch.int>
+func @torch.prim.unchecked_cast$derefine(%arg0: !torch.list<!torch.int>) -> !torch.list<!torch.int> {
+  %0 = torch.derefine %arg0 : !torch.list<!torch.int> to !torch.optional<!torch.list<!torch.int>>
+  %1 = torch.prim.unchecked_cast %0 : !torch.optional<!torch.list<!torch.int>> -> !torch.list<!torch.int>
+  return %1 : !torch.list<!torch.int>
+}
 
 // CHECK-LABEL: func @torch.aten.Int.Tensor(
 // CHECK-SAME:     %[[NUM:.*]]: !torch.int) -> !torch.int {
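
For reference, the checks in this file are driven by lit and FileCheck; assuming the file's usual harness, the RUN line is something like:

// RUN: torch-mlir-opt %s -canonicalize | FileCheck %s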
