Commit 1259e8a
authored
Add Some Folders For Small Reshape Ops (llvm#3813)
### Changes
1. Folders for view-like ops: `aten.view`, `aten.flatten.using_ints`,
and `aten.unflatten.int`
2. Folder for transpose
3. Extended support for the `aten.slice.Tensor` op folder to include
negative strides.
### Motivation
The biggest motivation for this patch is to fold the extremely
convoluted ir that gets generated when exporting a pytorch model with an
`aten.pad` op to ONNX, then re-importing and lowering back to torch. For
example, the verbose output of the e2e test `PadModule_basic` with `-c
onnx`:
```mlir
module {
func.func @main_graph(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?,?,?,?],f32> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 20 : si64, torch.onnx_meta.producer_name = "pytorch", torch.onnx_meta.producer_version = "2.5.0"} {
%none = torch.constant.none
%0 = torch.operator "onnx.Constant"() {torch.onnx.value = dense_resource<_> : tensor<1xsi64>} : () -> !torch.vtensor<[1],si64>
%1 = torch.operator "onnx.Constant"() {torch.onnx.value = dense_resource<__1> : tensor<4xsi64>} : () -> !torch.vtensor<[4],si64>
%2 = torch.operator "onnx.ConstantOfShape"(%0) {torch.onnx.value = dense_resource<__2> : tensor<1xsi64>} : (!torch.vtensor<[1],si64>) -> !torch.vtensor<[4],si64>
%3 = torch.operator "onnx.Concat"(%1, %2) {torch.onnx.axis = 0 : si64} : (!torch.vtensor<[4],si64>, !torch.vtensor<[4],si64>) -> !torch.vtensor<[8],si64>
%4 = torch.operator "onnx.Constant"() {torch.onnx.value = dense_resource<__3> : tensor<2xsi64>} : () -> !torch.vtensor<[2],si64>
%5 = torch.operator "onnx.Reshape"(%3, %4) {torch.onnx.allowzero = 0 : si64} : (!torch.vtensor<[8],si64>, !torch.vtensor<[2],si64>) -> !torch.vtensor<[4,2],si64>
%6 = torch.operator "onnx.Constant"() {torch.onnx.value = dense_resource<__4> : tensor<1xsi64>} : () -> !torch.vtensor<[1],si64>
%7 = torch.operator "onnx.Constant"() {torch.onnx.value = dense_resource<__5> : tensor<1xsi64>} : () -> !torch.vtensor<[1],si64>
%8 = torch.operator "onnx.Constant"() {torch.onnx.value = dense_resource<__6> : tensor<1xsi64>} : () -> !torch.vtensor<[1],si64>
%9 = torch.operator "onnx.Constant"() {torch.onnx.value = dense_resource<__7> : tensor<1xsi64>} : () -> !torch.vtensor<[1],si64>
%10 = torch.operator "onnx.Slice"(%5, %7, %8, %6, %9) : (!torch.vtensor<[4,2],si64>, !torch.vtensor<[1],si64>, !torch.vtensor<[1],si64>, !torch.vtensor<[1],si64>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[4,2],si64>
%11 = torch.operator "onnx.Transpose"(%10) {torch.onnx.perm = [1 : si64, 0 : si64]} : (!torch.vtensor<[4,2],si64>) -> !torch.vtensor<[2,4],si64>
%12 = torch.operator "onnx.Constant"() {torch.onnx.value = dense_resource<__8> : tensor<1xsi64>} : () -> !torch.vtensor<[1],si64>
%13 = torch.operator "onnx.Reshape"(%11, %12) {torch.onnx.allowzero = 0 : si64} : (!torch.vtensor<[2,4],si64>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[8],si64>
%14 = torch.operator "onnx.Cast"(%13) {torch.onnx.to = 7 : si64} : (!torch.vtensor<[8],si64>) -> !torch.vtensor<[8],si64>
%15 = torch.operator "onnx.Constant"() {torch.onnx.value = dense_resource<__9> : tensor<f32>} : () -> !torch.vtensor<[],f32>
%16 = torch.operator "onnx.Pad"(%arg0, %14, %15) {torch.onnx.mode = "constant"} : (!torch.vtensor<[?,?,?,?],f32>, !torch.vtensor<[8],si64>, !torch.vtensor<[],f32>) -> !torch.vtensor<[?,?,?,?],f32>
return %16 : !torch.vtensor<[?,?,?,?],f32>
}
}
{-#
dialect_resources: {
builtin: {
_: "0x080000000400000000000000",
__1: "0x080000000000000000000000010000000000000002000000000000000300000000000000",
__2: "0x080000000000000000000000",
__3: "0x08000000FFFFFFFFFFFFFFFF0200000000000000",
__4: "0x080000000000000000000000",
__5: "0x08000000FFFFFFFFFFFFFFFF",
__6: "0x080000000100000000000080",
__7: "0x08000000FFFFFFFFFFFFFFFF",
__8: "0x08000000FFFFFFFFFFFFFFFF",
__9: "0x080000000000C03F"
}
}
#-}
```
Get's converted to the torch IR:
```mlir
module {
func.func @main_graph(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?,?,?,?],f32> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 20 : si64, torch.onnx_meta.producer_name = "pytorch", torch.onnx_meta.producer_version = "2.5.0"} {
%float1.500000e00 = torch.constant.float 1.500000e+00
%int-9223372036854775807 = torch.constant.int -9223372036854775807
%int-1 = torch.constant.int -1
%int7 = torch.constant.int 7
%int6 = torch.constant.int 6
%int5 = torch.constant.int 5
%int3 = torch.constant.int 3
%int8 = torch.constant.int 8
%int1 = torch.constant.int 1
%int2 = torch.constant.int 2
%int4 = torch.constant.int 4
%int0 = torch.constant.int 0
%0 = torch.vtensor.literal(dense<[0, 1, 2, 3, 0, 0, 0, 0]> : tensor<8xsi64>) : !torch.vtensor<[8],si64>
%1 = torch.prim.ListConstruct %int4, %int2 : (!torch.int, !torch.int) -> !torch.list<int>
%2 = torch.aten.view %0, %1 : !torch.vtensor<[8],si64>, !torch.list<int> -> !torch.vtensor<[4,2],si64>
%3 = torch.aten.slice.Tensor %2, %int0, %int-1, %int-9223372036854775807, %int-1 : !torch.vtensor<[4,2],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,2],si64>
%4 = torch.aten.transpose.int %3, %int0, %int1 : !torch.vtensor<[4,2],si64>, !torch.int, !torch.int -> !torch.vtensor<[2,4],si64>
%5 = torch.prim.ListConstruct %int-1 : (!torch.int) -> !torch.list<int>
%6 = torch.aten.view %4, %5 : !torch.vtensor<[2,4],si64>, !torch.list<int> -> !torch.vtensor<[8],si64>
%7 = torch.aten.slice.Tensor %6, %int0, %int0, %int1, %int1 : !torch.vtensor<[8],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
%8 = torch.aten.item %7 : !torch.vtensor<[1],si64> -> !torch.int
%9 = torch.aten.slice.Tensor %6, %int0, %int1, %int2, %int1 : !torch.vtensor<[8],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
%10 = torch.aten.item %9 : !torch.vtensor<[1],si64> -> !torch.int
%11 = torch.aten.slice.Tensor %6, %int0, %int2, %int3, %int1 : !torch.vtensor<[8],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
%12 = torch.aten.item %11 : !torch.vtensor<[1],si64> -> !torch.int
%13 = torch.aten.slice.Tensor %6, %int0, %int3, %int4, %int1 : !torch.vtensor<[8],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
%14 = torch.aten.item %13 : !torch.vtensor<[1],si64> -> !torch.int
%15 = torch.aten.slice.Tensor %6, %int0, %int4, %int5, %int1 : !torch.vtensor<[8],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
%16 = torch.aten.item %15 : !torch.vtensor<[1],si64> -> !torch.int
%17 = torch.aten.slice.Tensor %6, %int0, %int5, %int6, %int1 : !torch.vtensor<[8],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
%18 = torch.aten.item %17 : !torch.vtensor<[1],si64> -> !torch.int
%19 = torch.aten.slice.Tensor %6, %int0, %int6, %int7, %int1 : !torch.vtensor<[8],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
%20 = torch.aten.item %19 : !torch.vtensor<[1],si64> -> !torch.int
%21 = torch.aten.slice.Tensor %6, %int0, %int7, %int8, %int1 : !torch.vtensor<[8],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
%22 = torch.aten.item %21 : !torch.vtensor<[1],si64> -> !torch.int
%23 = torch.prim.ListConstruct %14, %22, %12, %20, %10, %18, %8, %16 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
%24 = torch.aten.constant_pad_nd %arg0, %23, %float1.500000e00 : !torch.vtensor<[?,?,?,?],f32>, !torch.list<int>, !torch.float -> !torch.vtensor<[?,?,?,?],f32>
return %24 : !torch.vtensor<[?,?,?,?],f32>
}
}
```
***All of these operations are useless***. It is literally the result of
needing to reverse (and change the lexicographic order hierarchy of)
padding ints provided via torch vs. ONNX pad ops, which is then
subsequently UNDONE by our ONNX->Torch lowering (represented in the
ordering of the generated list construct).
With the added folders in this patch, the torch IR becomes:
```
module {
func.func @main_graph(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?,?,?,?],f32> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 20 : si64, torch.onnx_meta.producer_name = "pytorch", torch.onnx_meta.producer_version = "2.5.0"} {
%float1.500000e00 = torch.constant.float 1.500000e+00
%int0 = torch.constant.int 0
%int2 = torch.constant.int 2
%int3 = torch.constant.int 3
%int1 = torch.constant.int 1
%0 = torch.prim.ListConstruct %int0, %int1, %int2, %int3, %int0, %int0, %int0, %int0 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
%1 = torch.aten.constant_pad_nd %arg0, %0, %float1.500000e00 : !torch.vtensor<[?,?,?,?],f32>, !torch.list<int>, !torch.float -> !torch.vtensor<[?,?,?,?],f32>
return %1 : !torch.vtensor<[?,?,?,?],f32>
}
}
```1 parent d6feb21 commit 1259e8a
File tree
5 files changed
+200
-24
lines changed- include/torch-mlir/Dialect/Torch/IR
- lib/Dialect/Torch/IR
- projects/pt1
- e2e_testing
- python/torch_mlir/jit_ir_importer/build_tools
- test/Dialect/Torch
5 files changed
+200
-24
lines changed| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
8080 | 8080 | | |
8081 | 8081 | | |
8082 | 8082 | | |
| 8083 | + | |
8083 | 8084 | | |
8084 | 8085 | | |
8085 | 8086 | | |
| |||
9672 | 9673 | | |
9673 | 9674 | | |
9674 | 9675 | | |
| 9676 | + | |
9675 | 9677 | | |
9676 | 9678 | | |
9677 | 9679 | | |
| |||
9696 | 9698 | | |
9697 | 9699 | | |
9698 | 9700 | | |
| 9701 | + | |
9699 | 9702 | | |
9700 | 9703 | | |
9701 | 9704 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
30 | 30 | | |
31 | 31 | | |
32 | 32 | | |
| 33 | + | |
| 34 | + | |
| 35 | + | |
| 36 | + | |
| 37 | + | |
| 38 | + | |
| 39 | + | |
| 40 | + | |
| 41 | + | |
| 42 | + | |
| 43 | + | |
| 44 | + | |
| 45 | + | |
| 46 | + | |
| 47 | + | |
| 48 | + | |
| 49 | + | |
| 50 | + | |
33 | 51 | | |
34 | 52 | | |
35 | 53 | | |
| |||
1049 | 1067 | | |
1050 | 1068 | | |
1051 | 1069 | | |
| 1070 | + | |
| 1071 | + | |
1052 | 1072 | | |
1053 | 1073 | | |
1054 | 1074 | | |
| |||
2236 | 2256 | | |
2237 | 2257 | | |
2238 | 2258 | | |
| 2259 | + | |
| 2260 | + | |
| 2261 | + | |
| 2262 | + | |
| 2263 | + | |
| 2264 | + | |
| 2265 | + | |
| 2266 | + | |
2239 | 2267 | | |
2240 | 2268 | | |
2241 | 2269 | | |
2242 | 2270 | | |
| 2271 | + | |
| 2272 | + | |
| 2273 | + | |
| 2274 | + | |
2243 | 2275 | | |
2244 | 2276 | | |
2245 | 2277 | | |
| |||
3722 | 3754 | | |
3723 | 3755 | | |
3724 | 3756 | | |
| 3757 | + | |
| 3758 | + | |
| 3759 | + | |
| 3760 | + | |
| 3761 | + | |
| 3762 | + | |
| 3763 | + | |
| 3764 | + | |
| 3765 | + | |
| 3766 | + | |
| 3767 | + | |
| 3768 | + | |
| 3769 | + | |
| 3770 | + | |
| 3771 | + | |
| 3772 | + | |
| 3773 | + | |
| 3774 | + | |
| 3775 | + | |
| 3776 | + | |
| 3777 | + | |
| 3778 | + | |
| 3779 | + | |
| 3780 | + | |
| 3781 | + | |
| 3782 | + | |
| 3783 | + | |
| 3784 | + | |
| 3785 | + | |
| 3786 | + | |
| 3787 | + | |
| 3788 | + | |
| 3789 | + | |
| 3790 | + | |
| 3791 | + | |
| 3792 | + | |
| 3793 | + | |
| 3794 | + | |
| 3795 | + | |
| 3796 | + | |
| 3797 | + | |
| 3798 | + | |
| 3799 | + | |
| 3800 | + | |
| 3801 | + | |
| 3802 | + | |
| 3803 | + | |
| 3804 | + | |
| 3805 | + | |
| 3806 | + | |
| 3807 | + | |
| 3808 | + | |
| 3809 | + | |
| 3810 | + | |
| 3811 | + | |
| 3812 | + | |
| 3813 | + | |
| 3814 | + | |
| 3815 | + | |
| 3816 | + | |
| 3817 | + | |
| 3818 | + | |
| 3819 | + | |
3725 | 3820 | | |
3726 | 3821 | | |
3727 | 3822 | | |
| |||
3898 | 3993 | | |
3899 | 3994 | | |
3900 | 3995 | | |
3901 | | - | |
| 3996 | + | |
3902 | 3997 | | |
3903 | 3998 | | |
3904 | 3999 | | |
3905 | | - | |
3906 | | - | |
3907 | 4000 | | |
3908 | 4001 | | |
| 4002 | + | |
3909 | 4003 | | |
| 4004 | + | |
| 4005 | + | |
| 4006 | + | |
| 4007 | + | |
3910 | 4008 | | |
3911 | 4009 | | |
3912 | 4010 | | |
| |||
3919 | 4017 | | |
3920 | 4018 | | |
3921 | 4019 | | |
3922 | | - | |
3923 | | - | |
3924 | | - | |
3925 | | - | |
| 4020 | + | |
| 4021 | + | |
| 4022 | + | |
| 4023 | + | |
| 4024 | + | |
| 4025 | + | |
| 4026 | + | |
| 4027 | + | |
| 4028 | + | |
| 4029 | + | |
| 4030 | + | |
| 4031 | + | |
| 4032 | + | |
| 4033 | + | |
| 4034 | + | |
3926 | 4035 | | |
3927 | 4036 | | |
3928 | 4037 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
2677 | 2677 | | |
2678 | 2678 | | |
2679 | 2679 | | |
2680 | | - | |
2681 | | - | |
2682 | | - | |
2683 | | - | |
2684 | | - | |
2685 | | - | |
2686 | | - | |
2687 | | - | |
2688 | | - | |
2689 | | - | |
2690 | | - | |
2691 | | - | |
2692 | | - | |
2693 | | - | |
2694 | 2680 | | |
2695 | 2681 | | |
2696 | 2682 | | |
| |||
Lines changed: 5 additions & 3 deletions
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
684 | 684 | | |
685 | 685 | | |
686 | 686 | | |
687 | | - | |
| 687 | + | |
688 | 688 | | |
689 | 689 | | |
690 | 690 | | |
| |||
769 | 769 | | |
770 | 770 | | |
771 | 771 | | |
772 | | - | |
| 772 | + | |
773 | 773 | | |
774 | | - | |
| 774 | + | |
| 775 | + | |
| 776 | + | |
775 | 777 | | |
776 | 778 | | |
777 | 779 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
1682 | 1682 | | |
1683 | 1683 | | |
1684 | 1684 | | |
| 1685 | + | |
| 1686 | + | |
| 1687 | + | |
| 1688 | + | |
| 1689 | + | |
| 1690 | + | |
| 1691 | + | |
| 1692 | + | |
| 1693 | + | |
| 1694 | + | |
| 1695 | + | |
| 1696 | + | |
| 1697 | + | |
| 1698 | + | |
| 1699 | + | |
| 1700 | + | |
| 1701 | + | |
| 1702 | + | |
| 1703 | + | |
| 1704 | + | |
| 1705 | + | |
| 1706 | + | |
| 1707 | + | |
| 1708 | + | |
| 1709 | + | |
| 1710 | + | |
| 1711 | + | |
| 1712 | + | |
| 1713 | + | |
| 1714 | + | |
| 1715 | + | |
| 1716 | + | |
| 1717 | + | |
| 1718 | + | |
| 1719 | + | |
| 1720 | + | |
| 1721 | + | |
| 1722 | + | |
| 1723 | + | |
| 1724 | + | |
| 1725 | + | |
| 1726 | + | |
| 1727 | + | |
| 1728 | + | |
| 1729 | + | |
| 1730 | + | |
| 1731 | + | |
| 1732 | + | |
| 1733 | + | |
| 1734 | + | |
| 1735 | + | |
| 1736 | + | |
| 1737 | + | |
| 1738 | + | |
| 1739 | + | |
| 1740 | + | |
| 1741 | + | |
| 1742 | + | |
| 1743 | + | |
| 1744 | + | |
| 1745 | + | |
| 1746 | + | |
| 1747 | + | |
| 1748 | + | |
| 1749 | + | |
| 1750 | + | |
| 1751 | + | |
| 1752 | + | |
| 1753 | + | |
| 1754 | + | |
| 1755 | + | |
| 1756 | + | |
| 1757 | + | |
| 1758 | + | |
| 1759 | + | |
| 1760 | + | |
1685 | 1761 | | |
1686 | 1762 | | |
1687 | 1763 | | |
| |||
0 commit comments