
Commit a83004c

Author: Gaurav Shukla

[TORCH][MLIR] Fold trivial cases of aten.to.dtype and aten.view op

- It folds `aten.to.dtype` when the input tensor type and result type are exactly the same.
- It folds `aten.view` when both the input tensor type and result type have rank one.

Signed-off-by: Gaurav Shukla <[email protected]>

1 parent: 9e1ecf2
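At the PyTorch level both patterns are already no-ops, which is what makes these folds safe. A minimal illustration (a hypothetical snippet mirroring the new tests below, not part of the commit):

    import torch

    x = torch.rand(3, 5)                   # dtype is already torch.float32
    y = x.to(torch.float32, False, False)  # to(dtype, non_blocking=False, copy=False): nothing to convert
    a = torch.rand(32)                     # rank-1 tensor
    b = a.view(-1)                         # a 1-D view of a 1-D tensor leaves the shape unchanged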

File tree: 7 files changed (+108 −4 lines)


e2e_testing/torchscript/elementwise.py

Lines changed: 16 additions & 0 deletions

@@ -865,6 +865,22 @@ def forward(self, x):
 def ElementwiseToDtypeF32ToI64Module_basic(module, tu: TestUtils):
     module.forward(tu.rand(3, 5))

+class ElementwiseToDtypeIdentityModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1], torch.float32, True)
+    ])
+    def forward(self, x):
+        return x.to(torch.float32, False, False)
+
+@register_test_case(module_factory=lambda: ElementwiseToDtypeIdentityModule())
+def ElementwiseToDtypeIdentityModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(3, 5))
+
 class ElementwiseLog2Module(torch.nn.Module):
     def __init__(self):
         super().__init__()
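Worth noting why the new identity test is sound end to end: PyTorch documents that `Tensor.to` returns the input tensor itself when the requested dtype already matches and `copy=False`, so the folded and unfolded programs agree. A quick sketch of that behavior (illustrative, not part of the commit):

    import torch

    x = torch.rand(3, 5)
    y = x.to(torch.float32, False, False)  # matching dtype, copy=False
    assert y is x                          # PyTorch hands back the input itself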

e2e_testing/torchscript/view.py

Lines changed: 23 additions & 2 deletions

@@ -9,6 +9,7 @@
 from torch_mlir_e2e_test.torchscript.annotations import annotate_args, export

 # ==============================================================================
+
 class ViewExpandModule(torch.nn.Module):
     def __init__(self):
         super().__init__()
@@ -45,8 +46,8 @@ def forward(self, a):
 def ViewDynamicExpandModule_basic(module, tu: TestUtils):
     module.forward(tu.rand(2, 4, 30, 384))

-
 # ==============================================================================
+
 class ViewDynamicExpandWithAtenSizeIntModule(torch.nn.Module):
     def __init__(self):
         super().__init__()
@@ -65,6 +66,7 @@ def ViewDynamicExpandWithAtenSizeIntModule_basic(module, tu: TestUtils):
     module.forward(tu.rand(2, 4, 384))

 # ==============================================================================
+
 class ViewCollapseModule(torch.nn.Module):
     def __init__(self):
         super().__init__()
@@ -82,8 +84,8 @@ def forward(self, a):
 def ViewCollapseModule_basic(module, tu: TestUtils):
     module.forward(tu.rand(2, 4))

-
 # ==============================================================================
+
 class ViewCollapseDynamicWithAtenSizeIntModule(torch.nn.Module):
     def __init__(self):
         super().__init__()
@@ -102,3 +104,22 @@ def forward(self, a, b, c):
 @register_test_case(module_factory=lambda: ViewCollapseDynamicWithAtenSizeIntModule())
 def ViewCollapseDynamicWithAtenSizeIntModule_basic(module, tu: TestUtils):
     module.forward(tu.rand(2, 3, 5, 4, 12, 32), torch.tensor(3), torch.tensor(5))
+
+# ==============================================================================
+
+class View1DFoldModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1], torch.float32, True),
+    ])
+
+    def forward(self, a):
+        return a.view(-1)
+
+@register_test_case(module_factory=lambda: View1DFoldModule())
+def View1DFoldModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(32))
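The view counterpart is an identity for the same reason: reshaping an already 1-D tensor with `view(-1)` leaves the shape untouched and aliases the input's storage rather than copying. A small illustrative sketch (not part of the commit):

    import torch

    a = torch.rand(32)
    b = a.view(-1)                       # rank 1 -> rank 1
    assert b.shape == a.shape            # shape unchanged
    assert b.data_ptr() == a.data_ptr()  # same underlying storage, no copy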

e2e_testing/torchscript/xfail_sets.py

Lines changed: 2 additions & 0 deletions

@@ -43,4 +43,6 @@
     "SqueezeModule_allUnitDim",
     "TModuleRank1_basic",
     "TModuleRank0_basic",
+    "ElementwiseToDtypeIdentityModule_basic",
+    "View1DFoldModule_basic",
 }

include/torch-mlir/Dialect/Torch/IR/GeneratedAtenOps.td

Lines changed: 2 additions & 0 deletions

@@ -2399,6 +2399,7 @@ def Torch_AtenToDtypeOp : Torch_Op<"aten.to.dtype", [
     AnyTorchTensorType:$result
   );
   let assemblyFormat = "$self `,` $dtype `,` $non_blocking `,` $copy `,` $memory_format attr-dict `:` type($self) `,` type($dtype) `,` type($non_blocking) `,` type($copy) `,` type($memory_format) `->` type($result)";
+  let hasFolder = 1;
 }

 def Torch_AtenToOtherOp : Torch_Op<"aten.to.other", [
@@ -2462,6 +2463,7 @@ def Torch_AtenViewOp : Torch_Op<"aten.view", [
     AnyTorchTensorType:$result
   );
   let assemblyFormat = "$self `,` $size attr-dict `:` type($self) `,` type($size) `->` type($result)";
+  let hasFolder = 1;
 }

 def Torch_AtenWhereSelfOp : Torch_Op<"aten.where.self", [
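Note that these two `let hasFolder = 1;` lines are not hand-written: GeneratedAtenOps.td is emitted by the ODS generator, and the matching `has_folder=True` arguments appear in the torch_ods_gen.py change below. Declaring `hasFolder` is what makes ODS expect the `fold` implementations added in TorchOps.cpp.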

lib/Dialect/Torch/IR/TorchOps.cpp

Lines changed: 42 additions & 0 deletions

@@ -474,6 +474,48 @@ OpFoldResult AtenSqueezeDimOp::fold(ArrayRef<Attribute> operands) {
   return nullptr;
 }

+//===----------------------------------------------------------------------===//
+// AtenToDtypeOp
+//===----------------------------------------------------------------------===//
+
+OpFoldResult AtenToDtypeOp::fold(ArrayRef<Attribute> operands) {
+  bool nonBlocking, copyArg;
+  // The non_blocking arg must be `False`.
+  if (!matchPattern(non_blocking(), m_TorchConstantBool(&nonBlocking)) ||
+      nonBlocking)
+    return nullptr;
+  // The copy arg must be `False`.
+  if (!matchPattern(copy(), m_TorchConstantBool(&copyArg)) || copyArg)
+    return nullptr;
+  // The memory_format arg must be `none`.
+  if (!memory_format().getType().isa<Torch::NoneType>())
+    return nullptr;
+
+  auto inputType = getOperand(0).getType().dyn_cast<BaseTensorType>();
+  if (!inputType || !inputType.hasSizes())
+    return nullptr;
+  auto resType = getType().dyn_cast<BaseTensorType>();
+  if (!resType || !resType.hasSizes() || inputType != resType)
+    return nullptr;
+  // Fold when both the input tensor and result are of the same type.
+  return getOperand(0);
+}
+
+//===----------------------------------------------------------------------===//
+// AtenViewOp
+//===----------------------------------------------------------------------===//
+
+OpFoldResult AtenViewOp::fold(ArrayRef<Attribute> operands) {
+  auto inputType = getOperand(0).getType().dyn_cast<BaseTensorType>();
+  if (!inputType || !inputType.hasSizes() || inputType.getSizes().size() != 1)
+    return nullptr;
+  auto resType = getType().dyn_cast<BaseTensorType>();
+  if (!resType || !resType.hasSizes() || resType.getSizes().size() != 1)
+    return nullptr;
+  // Fold when both the input tensor and result are unity rank tensors.
+  return getOperand(0);
+}
+
 //===----------------------------------------------------------------------===//
 // AtenDimOp
 //===----------------------------------------------------------------------===//
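A note on the folder contract: returning an existing SSA value (here `getOperand(0)`) tells MLIR to replace every use of the op's result with that value, while returning `nullptr` leaves the op untouched. Both folders are deliberately conservative. `AtenToDtypeOp::fold` bails out unless `non_blocking` and `copy` are constant `false`, `memory_format` is `none`, and the input and result tensor types are identical, so substituting the input is always type-safe. `AtenViewOp::fold` fires only when both the input and result types carry shape information and are rank 1. The unused `operands` array would hold constant attributes for the operands; these folds do not need them.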

python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py

Lines changed: 2 additions & 2 deletions

@@ -580,11 +580,11 @@ def emit_with_mutating_variants(key, **kwargs):
     emit("aten::stack : (Tensor[], int) -> (Tensor)")
     emit("aten::sum : (Tensor, int?) -> (Tensor)")
     emit("aten::sum.dim_IntList : (Tensor, int[], bool, int?) -> (Tensor)")
-    emit("aten::to.dtype : (Tensor, int, bool, bool, int?) -> (Tensor)")
+    emit("aten::to.dtype : (Tensor, int, bool, bool, int?) -> (Tensor)", has_folder=True)
     emit("aten::to.other : (Tensor, Tensor, bool, bool, int?) -> (Tensor)")
     emit("aten::to.prim_Device : (Tensor, Device?, int?, bool, bool) -> (Tensor)")
     emit("aten::type_as : (Tensor, Tensor) -> (Tensor)")
-    emit("aten::view : (Tensor, int[]) -> (Tensor)")
+    emit("aten::view : (Tensor, int[]) -> (Tensor)", has_folder=True)
     emit("aten::where.self : (Tensor, Tensor, Tensor) -> (Tensor)")
     emit("aten::slice.Tensor : (Tensor, int, int?, int?, int) -> (Tensor)")
     emit("aten::len.Tensor : (Tensor) -> (int)")

test/Dialect/Torch/canonicalize.mlir

Lines changed: 21 additions & 0 deletions

@@ -622,3 +622,24 @@ func @torch.aten.squeeze.dim$zero_rank(%arg0: !torch.tensor<[],f32>) -> !torch.t
   %0 = torch.aten.squeeze.dim %arg0, %int0 : !torch.tensor<[],f32>, !torch.int -> !torch.tensor<[],f32>
   return %0 : !torch.tensor<[],f32>
 }
+
+// CHECK-LABEL:   func @torch.aten.to.dtype$same_dtype(
+// CHECK-SAME:        %[[ARG:.*]]: !torch.tensor<[?,?],f32>) -> !torch.tensor<[?,?],f32> {
+// CHECK-NEXT:        return %[[ARG]] : !torch.tensor<[?,?],f32>
+func @torch.aten.to.dtype$same_dtype(%arg0: !torch.tensor<[?,?],f32>) -> !torch.tensor<[?,?],f32> {
+  %none = torch.constant.none
+  %false = torch.constant.bool false
+  %int6 = torch.constant.int 6
+  %0 = torch.aten.to.dtype %arg0, %int6, %false, %false, %none : !torch.tensor<[?,?],f32>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.tensor<[?,?],f32>
+  return %0 : !torch.tensor<[?,?],f32>
+}
+
+// CHECK-LABEL:   func @torch.aten.view$1D(
+// CHECK-SAME:        %[[ARG:.*]]: !torch.tensor<[?],f32>) -> !torch.tensor<[?],f32> {
+// CHECK-NEXT:        return %[[ARG]] : !torch.tensor<[?],f32>
+func @torch.aten.view$1D(%arg0: !torch.tensor<[?],f32>) -> !torch.tensor<[?],f32> {
+  %int-1 = torch.constant.int -1
+  %0 = torch.prim.ListConstruct %int-1 : (!torch.int) -> !torch.list<!torch.int>
+  %1 = torch.aten.view %arg0, %0 : !torch.tensor<[?],f32>, !torch.list<!torch.int> -> !torch.tensor<[?],f32>
+  return %1 : !torch.tensor<[?],f32>
+}
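Both checks exercise the new folders through canonicalization; the RUN line at the top of this test file (outside the hunk) typically pipes it through `torch-mlir-opt -canonicalize` and FileCheck. In the first case, `%int6` is the torch ScalarType code for float32, so the `aten.to.dtype` call requests the dtype the input already has and folds away to a direct return of the argument; in the second, the rank-1 `aten.view` folds the same way.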
