Commit 1259e8a

Add Some Folders For Small Reshape Ops (#3813)
### Changes

1. Folders for view-like ops: `aten.view`, `aten.flatten.using_ints`, and `aten.unflatten.int`
2. Folder for transpose (`aten.transpose.int`)
3. Extended support in the `aten.slice.Tensor` folder for negative strides

### Motivation

The biggest motivation for this patch is to fold the extremely convoluted IR that gets generated when exporting a PyTorch model with an `aten.pad` op to ONNX, then re-importing and lowering back to torch. For example, the verbose output of the e2e test `PadModule_basic` with `-c onnx`:

```mlir
module {
  func.func @main_graph(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?,?,?,?],f32> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 20 : si64, torch.onnx_meta.producer_name = "pytorch", torch.onnx_meta.producer_version = "2.5.0"} {
    %none = torch.constant.none
    %0 = torch.operator "onnx.Constant"() {torch.onnx.value = dense_resource<_> : tensor<1xsi64>} : () -> !torch.vtensor<[1],si64>
    %1 = torch.operator "onnx.Constant"() {torch.onnx.value = dense_resource<__1> : tensor<4xsi64>} : () -> !torch.vtensor<[4],si64>
    %2 = torch.operator "onnx.ConstantOfShape"(%0) {torch.onnx.value = dense_resource<__2> : tensor<1xsi64>} : (!torch.vtensor<[1],si64>) -> !torch.vtensor<[4],si64>
    %3 = torch.operator "onnx.Concat"(%1, %2) {torch.onnx.axis = 0 : si64} : (!torch.vtensor<[4],si64>, !torch.vtensor<[4],si64>) -> !torch.vtensor<[8],si64>
    %4 = torch.operator "onnx.Constant"() {torch.onnx.value = dense_resource<__3> : tensor<2xsi64>} : () -> !torch.vtensor<[2],si64>
    %5 = torch.operator "onnx.Reshape"(%3, %4) {torch.onnx.allowzero = 0 : si64} : (!torch.vtensor<[8],si64>, !torch.vtensor<[2],si64>) -> !torch.vtensor<[4,2],si64>
    %6 = torch.operator "onnx.Constant"() {torch.onnx.value = dense_resource<__4> : tensor<1xsi64>} : () -> !torch.vtensor<[1],si64>
    %7 = torch.operator "onnx.Constant"() {torch.onnx.value = dense_resource<__5> : tensor<1xsi64>} : () -> !torch.vtensor<[1],si64>
    %8 = torch.operator "onnx.Constant"() {torch.onnx.value = dense_resource<__6> : tensor<1xsi64>} : () -> !torch.vtensor<[1],si64>
    %9 = torch.operator "onnx.Constant"() {torch.onnx.value = dense_resource<__7> : tensor<1xsi64>} : () -> !torch.vtensor<[1],si64>
    %10 = torch.operator "onnx.Slice"(%5, %7, %8, %6, %9) : (!torch.vtensor<[4,2],si64>, !torch.vtensor<[1],si64>, !torch.vtensor<[1],si64>, !torch.vtensor<[1],si64>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[4,2],si64>
    %11 = torch.operator "onnx.Transpose"(%10) {torch.onnx.perm = [1 : si64, 0 : si64]} : (!torch.vtensor<[4,2],si64>) -> !torch.vtensor<[2,4],si64>
    %12 = torch.operator "onnx.Constant"() {torch.onnx.value = dense_resource<__8> : tensor<1xsi64>} : () -> !torch.vtensor<[1],si64>
    %13 = torch.operator "onnx.Reshape"(%11, %12) {torch.onnx.allowzero = 0 : si64} : (!torch.vtensor<[2,4],si64>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[8],si64>
    %14 = torch.operator "onnx.Cast"(%13) {torch.onnx.to = 7 : si64} : (!torch.vtensor<[8],si64>) -> !torch.vtensor<[8],si64>
    %15 = torch.operator "onnx.Constant"() {torch.onnx.value = dense_resource<__9> : tensor<f32>} : () -> !torch.vtensor<[],f32>
    %16 = torch.operator "onnx.Pad"(%arg0, %14, %15) {torch.onnx.mode = "constant"} : (!torch.vtensor<[?,?,?,?],f32>, !torch.vtensor<[8],si64>, !torch.vtensor<[],f32>) -> !torch.vtensor<[?,?,?,?],f32>
    return %16 : !torch.vtensor<[?,?,?,?],f32>
  }
}

{-#
  dialect_resources: {
    builtin: {
      _: "0x080000000400000000000000",
      __1: "0x080000000000000000000000010000000000000002000000000000000300000000000000",
      __2: "0x080000000000000000000000",
      __3: "0x08000000FFFFFFFFFFFFFFFF0200000000000000",
      __4: "0x080000000000000000000000",
      __5: "0x08000000FFFFFFFFFFFFFFFF",
      __6: "0x080000000100000000000080",
      __7: "0x08000000FFFFFFFFFFFFFFFF",
      __8: "0x08000000FFFFFFFFFFFFFFFF",
      __9: "0x080000000000C03F"
    }
  }
#-}
```

gets converted to the torch IR:

```mlir
module {
  func.func @main_graph(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?,?,?,?],f32> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 20 : si64, torch.onnx_meta.producer_name = "pytorch", torch.onnx_meta.producer_version = "2.5.0"} {
    %float1.500000e00 = torch.constant.float 1.500000e+00
    %int-9223372036854775807 = torch.constant.int -9223372036854775807
    %int-1 = torch.constant.int -1
    %int7 = torch.constant.int 7
    %int6 = torch.constant.int 6
    %int5 = torch.constant.int 5
    %int3 = torch.constant.int 3
    %int8 = torch.constant.int 8
    %int1 = torch.constant.int 1
    %int2 = torch.constant.int 2
    %int4 = torch.constant.int 4
    %int0 = torch.constant.int 0
    %0 = torch.vtensor.literal(dense<[0, 1, 2, 3, 0, 0, 0, 0]> : tensor<8xsi64>) : !torch.vtensor<[8],si64>
    %1 = torch.prim.ListConstruct %int4, %int2 : (!torch.int, !torch.int) -> !torch.list<int>
    %2 = torch.aten.view %0, %1 : !torch.vtensor<[8],si64>, !torch.list<int> -> !torch.vtensor<[4,2],si64>
    %3 = torch.aten.slice.Tensor %2, %int0, %int-1, %int-9223372036854775807, %int-1 : !torch.vtensor<[4,2],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,2],si64>
    %4 = torch.aten.transpose.int %3, %int0, %int1 : !torch.vtensor<[4,2],si64>, !torch.int, !torch.int -> !torch.vtensor<[2,4],si64>
    %5 = torch.prim.ListConstruct %int-1 : (!torch.int) -> !torch.list<int>
    %6 = torch.aten.view %4, %5 : !torch.vtensor<[2,4],si64>, !torch.list<int> -> !torch.vtensor<[8],si64>
    %7 = torch.aten.slice.Tensor %6, %int0, %int0, %int1, %int1 : !torch.vtensor<[8],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
    %8 = torch.aten.item %7 : !torch.vtensor<[1],si64> -> !torch.int
    %9 = torch.aten.slice.Tensor %6, %int0, %int1, %int2, %int1 : !torch.vtensor<[8],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
    %10 = torch.aten.item %9 : !torch.vtensor<[1],si64> -> !torch.int
    %11 = torch.aten.slice.Tensor %6, %int0, %int2, %int3, %int1 : !torch.vtensor<[8],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
    %12 = torch.aten.item %11 : !torch.vtensor<[1],si64> -> !torch.int
    %13 = torch.aten.slice.Tensor %6, %int0, %int3, %int4, %int1 : !torch.vtensor<[8],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
    %14 = torch.aten.item %13 : !torch.vtensor<[1],si64> -> !torch.int
    %15 = torch.aten.slice.Tensor %6, %int0, %int4, %int5, %int1 : !torch.vtensor<[8],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
    %16 = torch.aten.item %15 : !torch.vtensor<[1],si64> -> !torch.int
    %17 = torch.aten.slice.Tensor %6, %int0, %int5, %int6, %int1 : !torch.vtensor<[8],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
    %18 = torch.aten.item %17 : !torch.vtensor<[1],si64> -> !torch.int
    %19 = torch.aten.slice.Tensor %6, %int0, %int6, %int7, %int1 : !torch.vtensor<[8],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
    %20 = torch.aten.item %19 : !torch.vtensor<[1],si64> -> !torch.int
    %21 = torch.aten.slice.Tensor %6, %int0, %int7, %int8, %int1 : !torch.vtensor<[8],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
    %22 = torch.aten.item %21 : !torch.vtensor<[1],si64> -> !torch.int
    %23 = torch.prim.ListConstruct %14, %22, %12, %20, %10, %18, %8, %16 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
    %24 = torch.aten.constant_pad_nd %arg0, %23, %float1.500000e00 : !torch.vtensor<[?,?,?,?],f32>, !torch.list<int>, !torch.float -> !torch.vtensor<[?,?,?,?],f32>
    return %24 : !torch.vtensor<[?,?,?,?],f32>
  }
}
```

***All of these operations are useless.*** They exist only because the padding ints have to be reversed (and regrouped) between the torch and ONNX pad conventions, and that reordering is then immediately undone by our ONNX->Torch lowering (reflected in the ordering of the generated list construct). With the folders added in this patch, the torch IR becomes:

```mlir
module {
  func.func @main_graph(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?,?,?,?],f32> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 20 : si64, torch.onnx_meta.producer_name = "pytorch", torch.onnx_meta.producer_version = "2.5.0"} {
    %float1.500000e00 = torch.constant.float 1.500000e+00
    %int0 = torch.constant.int 0
    %int2 = torch.constant.int 2
    %int3 = torch.constant.int 3
    %int1 = torch.constant.int 1
    %0 = torch.prim.ListConstruct %int0, %int1, %int2, %int3, %int0, %int0, %int0, %int0 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
    %1 = torch.aten.constant_pad_nd %arg0, %0, %float1.500000e00 : !torch.vtensor<[?,?,?,?],f32>, !torch.list<int>, !torch.float -> !torch.vtensor<[?,?,?,?],f32>
    return %1 : !torch.vtensor<[?,?,?,?],f32>
  }
}
```
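For context, the two pad conventions differ only in how the per-dimension (begin, end) values are laid out: `aten.constant_pad_nd` takes pairs starting from the last dimension, while `onnx.Pad` takes all begin values in dimension order followed by all end values. The exporter's reshape/flip/transpose dance and the ONNX->Torch lowering's element-by-element rebuild are both just this reordering. A minimal Python sketch of the round trip (the helper names are mine, not code from either project):

```python
# Illustration only: the two pad layouts and the conversion between them.
# torch (constant_pad_nd): [last_begin, last_end, second_last_begin, second_last_end, ...]
# ONNX Pad:                [dim0_begin, ..., dimN_begin, dim0_end, ..., dimN_end]

def torch_pads_to_onnx(pads, rank):
    pads = pads + [0] * (2 * rank - len(pads))               # cover all dims
    pairs = [pads[i:i + 2] for i in range(0, len(pads), 2)]  # (begin, end) pairs, last dim first
    pairs.reverse()                                          # reorder to dim0..dimN
    return [p[0] for p in pairs] + [p[1] for p in pairs]

def onnx_pads_to_torch(pads, rank):
    begins, ends = pads[:rank], pads[rank:]
    out = []
    for b, e in reversed(list(zip(begins, ends))):           # back to last-dim-first pairs
        out += [b, e]
    return out

torch_pads = [0, 1, 2, 3, 0, 0, 0, 0]                        # PadModule_basic, rank-4 input
onnx_pads = torch_pads_to_onnx(torch_pads, 4)                # [0, 0, 2, 0, 0, 0, 3, 1]
assert onnx_pads_to_torch(onnx_pads, 4) == torch_pads        # the round trip the folders eliminate
```

For the example above, the ONNX pads work out to `[0, 0, 2, 0, 0, 0, 3, 1]`, and converting them back yields exactly the folded `ListConstruct` `[0, 1, 2, 3, 0, 0, 0, 0]`.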
1 parent d6feb21 commit 1259e8a

5 files changed (+200, -24 lines)

include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td

Lines changed: 3 additions & 0 deletions
```diff
@@ -8080,6 +8080,7 @@ def Torch_AtenTransposeIntOp : Torch_Op<"aten.transpose.int", [
       printDefaultTorchOp(printer, *this, 3, 1);
     }
   }];
+  let hasFolder = 1;
 }
 
 def Torch_AtenPixelShuffleOp : Torch_Op<"aten.pixel_shuffle", [
@@ -9672,6 +9673,7 @@ def Torch_AtenFlattenUsingIntsOp : Torch_Op<"aten.flatten.using_ints", [
       printDefaultTorchOp(printer, *this, 3, 1);
     }
   }];
+  let hasFolder = 1;
 }
 
 def Torch_AtenUnflattenIntOp : Torch_Op<"aten.unflatten.int", [
@@ -9696,6 +9698,7 @@ def Torch_AtenUnflattenIntOp : Torch_Op<"aten.unflatten.int", [
       printDefaultTorchOp(printer, *this, 3, 1);
     }
   }];
+  let hasFolder = 1;
   let hasCanonicalizer = 1;
 }
 
```

lib/Dialect/Torch/IR/TorchOps.cpp

Lines changed: 116 additions & 7 deletions
```diff
@@ -30,6 +30,24 @@ using namespace mlir::torch::Torch;
 // Utilities
 //===----------------------------------------------------------------------===//
 
+OpFoldResult genericViewLikeFold(Attribute self, Type resultType) {
+  auto selfAttr = dyn_cast_or_null<DenseElementsAttr>(self);
+  if (!selfAttr)
+    return nullptr;
+
+  auto resultTy = dyn_cast_or_null<ValueTensorType>(resultType);
+  if (!resultTy || !resultTy.areAllSizesKnown())
+    return nullptr;
+
+  if (selfAttr.isSplat()) {
+    return SplatElementsAttr::get(resultTy.toBuiltinTensor(),
+                                  selfAttr.getSplatValue<Attribute>());
+  }
+  return DenseElementsAttr::get(
+      resultTy.toBuiltinTensor(),
+      llvm::to_vector(selfAttr.getValues<Attribute>()));
+}
+
 Value mlir::torch::Torch::adjustStaticInformation(OpBuilder &builder,
                                                   Location loc, Value value,
                                                   Type desiredType,
@@ -1049,6 +1067,8 @@ void Aten_CastLongOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
 //===----------------------------------------------------------------------===//
 
 OpFoldResult AtenViewOp::fold(FoldAdaptor adaptor) {
+  if (auto genericFold = genericViewLikeFold(adaptor.getSelf(), getType()))
+    return genericFold;
   auto inputType = dyn_cast<BaseTensorType>(getOperand(0).getType());
   if (!inputType || !inputType.hasSizes() || inputType.getSizes().size() != 1)
     return nullptr;
@@ -2236,10 +2256,22 @@ void AtenSizeOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
   });
 }
 
+//===----------------------------------------------------------------------===//
+// AtenFlattenUsingIntsOp
+//===----------------------------------------------------------------------===//
+
+OpFoldResult AtenFlattenUsingIntsOp::fold(FoldAdaptor adaptor) {
+  return genericViewLikeFold(adaptor.getSelf(), getType());
+}
+
 //===----------------------------------------------------------------------===//
 // AtenUnflattenIntOp
 //===----------------------------------------------------------------------===//
 
+OpFoldResult AtenUnflattenIntOp::fold(FoldAdaptor adaptor) {
+  return genericViewLikeFold(adaptor.getSelf(), getType());
+}
+
 void AtenUnflattenIntOp::getCanonicalizationPatterns(
     RewritePatternSet &patterns, MLIRContext *context) {
   // if there are only two sizes and one of them is statically 1, then convert
```
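The `genericViewLikeFold` helper only fires on constant inputs with a fully static result type: it reuses the same flat element data under the new shape, and keeps splat constants as splats. A rough Python analogue, purely for illustration (the function name and shape encoding are mine):

```python
# Illustration only: what a view/flatten/unflatten fold amounts to for constants.
import numpy as np

def view_like_fold(values, result_shape):
    # Bail out unless every result dimension is statically known,
    # mirroring the areAllSizesKnown() check in the C++ helper.
    if any(d is None or d < 0 for d in result_shape):
        return None
    # Same flat data, new shape; a splat constant stays a single repeated value.
    return np.asarray(values).reshape(result_shape)

print(view_like_fold(np.arange(8), (1, 4, 2)))   # [[[0 1] [2 3] [4 5] [6 7]]]
print(view_like_fold(np.full(8, 2), (2, 4, 1)))  # all twos, now shaped (2, 4, 1)
```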
```diff
@@ -3722,6 +3754,69 @@ OpFoldResult AtenSubIntOp::fold(FoldAdaptor adaptor) {
       adaptor.getOperands(), [](int64_t a, int64_t b) { return a - b; });
 }
 
+//===----------------------------------------------------------------------===//
+// AtenTransposeIntOp
+//===----------------------------------------------------------------------===//
+
+OpFoldResult AtenTransposeIntOp::fold(FoldAdaptor adaptor) {
+  // first check for no-op
+  IntegerAttr dim0 = dyn_cast_or_null<IntegerAttr>(adaptor.getDim0());
+  IntegerAttr dim1 = dyn_cast_or_null<IntegerAttr>(adaptor.getDim1());
+  if (!dim0 || !dim1)
+    return nullptr;
+  int64_t _dim0 = dim0.getValue().getSExtValue();
+  int64_t _dim1 = dim1.getValue().getSExtValue();
+  auto selfTy = dyn_cast<ValueTensorType>(getSelf().getType());
+  if (!selfTy || !selfTy.hasSizes())
+    return nullptr;
+  int64_t rank = selfTy.getSizes().size();
+  _dim0 = toPositiveDim(_dim0, rank);
+  _dim1 = toPositiveDim(_dim1, rank);
+  if (!isValidDim(_dim0, rank) || !isValidDim(_dim1, rank))
+    return nullptr;
+  // if dims are the same, return self
+  if (_dim0 == _dim1)
+    return getSelf();
+
+  // We set a maximum folding size of 16. This is a reasonable upper limit
+  // for shape computations.
+  constexpr int64_t kMaxFoldSize = 16;
+  auto self = dyn_cast_or_null<DenseElementsAttr>(adaptor.getSelf());
+  if (!self || self.getNumElements() > kMaxFoldSize)
+    return nullptr;
+  auto resultTy = dyn_cast<ValueTensorType>(getType());
+  if (!selfTy || !resultTy || !selfTy.areAllSizesKnown())
+    return nullptr;
+  if (self.isSplat())
+    return SplatElementsAttr::get(resultTy.toBuiltinTensor(),
+                                  self.getSplatValue<Attribute>());
+
+  // TODO: add support for rank != 2
+  if (rank != 2)
+    return nullptr;
+
+  ArrayRef<int64_t> sizes = selfTy.getSizes();
+  auto values = llvm::to_vector(self.getValues<Attribute>());
+  // reordered[i] = Trans[i//sizes[0], i % sizes[0]] = Self[i % sizes[0],
+  // i//sizes[0]] = values[(i % sizes[0])*sizes[1] + (i//sizes[0])].
+  // e.g., Self size = [4,2]; Trans size = [2,4].
+  // reindex(i) = (i % 4)*2 + (i // 4) .
+  // i = 0 -> Trans[0,0] -> Self[0,0] -> 0 .
+  // i = 1 -> Trans[0,1] -> Self[1,0] -> 2 .
+  // i = 2 -> Trans[0,2] -> Self[2,0] -> 4 .
+  // i = 3 -> Trans[0,3] -> Self[3,0] -> 6 .
+  // i = 4 -> Trans[1,0] -> Self[0,1] -> 1 .
+  // i = 5 -> Trans[1,1] -> Self[1,1] -> 3 .
+  auto reindex = [&](int64_t i) {
+    return (i % sizes[0]) * sizes[1] + (i / sizes[0]);
+  };
+  SmallVector<Attribute> reordered;
+  for (int64_t i = 0; i < self.getNumElements(); i++) {
+    reordered.push_back(values[reindex(i)]);
+  }
+  return DenseElementsAttr::get(resultTy.toBuiltinTensor(), reordered);
+}
+
 //===----------------------------------------------------------------------===//
 // AtenCatOp
 //===----------------------------------------------------------------------===//
```
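As a quick sanity check of the `reindex` formula (my own illustration, not part of the patch), the 4x2 example from the comment above works out as follows:

```python
# Rank-2 transpose as a flat reindexing: reindex(i) = (i % rows) * cols + i // rows.
rows, cols = 4, 2                      # "Self" is 4x2, "Trans" is 2x4
values = list(range(rows * cols))      # Self in row-major order: [[0,1],[2,3],[4,5],[6,7]]

reindex = lambda i: (i % rows) * cols + i // rows
reordered = [values[reindex(i)] for i in range(rows * cols)]

# Row-major layout of the transposed [[0,2,4,6],[1,3,5,7]]:
assert reordered == [0, 2, 4, 6, 1, 3, 5, 7]
```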
```diff
@@ -3898,15 +3993,18 @@ OpFoldResult AtenSliceTensorOp::fold(FoldAdaptor adaptor) {
   // Fold the slice if the output tensor is relatively small, currently
   // coded to 16:
   constexpr int64_t kMaxFold = 16;
-  if (input && start && step && dim && count <= kMaxFold) {
+  if (input && start && step && dim && end && count <= kMaxFold) {
     int64_t begin = start.getValue().getSExtValue();
     int64_t limit = end.getValue().getSExtValue();
     int64_t stride = step.getValue().getSExtValue();
-    if (stride < 1)
-      return nullptr;
     begin = begin < 0 ? begin + inType.getSizes()[dimInt] : begin;
     limit = limit < 0 ? limit + inType.getSizes()[dimInt] : limit;
+    limit = limit < 0 ? -1 : limit;
     limit = std::min(limit, inType.getSizes()[dimInt]);
+    bool validIterArgs =
+        (stride > 0 && begin < limit) || (stride < 0 && begin > limit);
+    assert(validIterArgs &&
+           "aten.slice.Tensor iteration args are statically invalid.");
 
     int64_t inputRank = inType.getSizes().size();
     llvm::SmallVector<int64_t> inputStrides(inputRank, 1);
@@ -3919,10 +4017,21 @@ OpFoldResult AtenSliceTensorOp::fold(FoldAdaptor adaptor) {
     auto recursiveIter = [&](auto &self, int64_t currDim, int64_t currOffset) {
       if (currDim >= inputRank)
         return;
-      size_t _begin = (currDim == dimInt) ? begin : 0;
-      size_t _limit = (currDim == dimInt) ? limit : inType.getSizes()[currDim];
-      size_t _stride = (currDim == dimInt) ? stride : 1;
-      for (size_t i = _begin; i < _limit; i += _stride) {
+      int64_t _stride = (currDim == dimInt) ? stride : 1;
+      int64_t _begin = (currDim == dimInt) ? begin : 0;
+      int64_t _limit = (currDim == dimInt) ? limit : inType.getSizes()[currDim];
+      // ensure that the limit is reached exactly (even with negative strides)
+      // E.g., with begin = 0, limit = 10, stride = 3, we modify limit to be 11
+      // = 10 + (10-0) % 3 .
+      // E.g., with begin = 8, limit = -1, stride = -2, limit becomes -2 = -1 +
+      // (-1-8) % (-2) - stride = -1 + 1 - 2 = -2 .
+      // Note: cpp uses true math remainder "n % d = least positive int, x, such
+      // that d divides (n - x)"
+      int64_t limit_rem = (_limit - _begin) % _stride;
+      limit_rem =
+          (_stride > 0 || limit_rem == 0) ? limit_rem : limit_rem - _stride;
+      _limit += limit_rem;
+      for (int64_t i = _begin; std::abs(_limit - i) > 0; i += _stride) {
         if (currDim == inputRank - 1) {
           values.push_back(input.getValues<Attribute>()[currOffset + i]);
         }
```
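With the `stride < 1` bail-out gone, the folder now has to reproduce negative-stride slice semantics. A plain-Python illustration (mine, not from the patch) of the element selection the two new slice test cases further below expect, using the same start/end/step constants:

```python
# Illustration only: the selections the negative-stride slice folds must produce.
rows = [[0, 1], [2, 3], [4, 5], [6, 7]]   # the 4x2 literal used in the tests

# flip_slice_fold: start = -1, end = -9223372036854775807, step = -1 on dim 0
assert rows[-1:-9223372036854775807:-1] == [[6, 7], [4, 5], [2, 3], [0, 1]]

# negative_two_stride_fold: start = -1, end = -5, step = -2 on dim 0
assert rows[-1:-5:-2] == [[6, 7], [2, 3]]
```

Both selections match the `torch.vtensor.literal` constants checked in `test/Dialect/Torch/canonicalize.mlir` below.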

projects/pt1/e2e_testing/xfail_sets.py

Lines changed: 0 additions & 14 deletions
```diff
@@ -2677,20 +2677,6 @@
     "MultinomialModule2D_basic",
     "MultinomialModule2D_F32",
     "PixelShuffleModuleStaticRank4Float32_basic",
-    "ReflectionPad1dModule2dInput_Right",
-    "ReflectionPad1dModule2dInput_basic",
-    "ReflectionPad1dModule3dInput_Left",
-    "ReflectionPad1dModule3dInput_basic",
-    "ReflectionPad2dModule_Bottom",
-    "ReflectionPad2dModule_Left",
-    "ReflectionPad2dModule_Right",
-    "ReflectionPad2dModule_Top",
-    "ReflectionPad2dModule_basic",
-    "ReplicationPad2dModule_basic",
-    "ReplicationPad2dModule_bottom0",
-    "ReplicationPad2dModule_left0",
-    "ReplicationPad2dModule_right0",
-    "ReplicationPad2dModule_top0",
     "SliceCopyEndGreaterThanDimSize_Module_basic",
     "SliceCopyNegative_Module_basic",
     "SliceCopyNonZeroDim_Module_basic",
```

projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py

Lines changed: 5 additions & 3 deletions
```diff
@@ -684,7 +684,7 @@ def emit_with_mutating_variants(key, **kwargs):
     emit("aten::adaptive_max_pool2d : (Tensor, int[]) -> (Tensor, Tensor)")
     emit("aten::adaptive_max_pool3d : (Tensor, int[]) -> (Tensor, Tensor)")
     emit("aten::topk : (Tensor, int, int, bool, bool) -> (Tensor, Tensor)")
-    emit("aten::transpose.int : (Tensor, int, int) -> (Tensor)")
+    emit("aten::transpose.int : (Tensor, int, int) -> (Tensor)", has_folder=True)
     emit("aten::pixel_shuffle : (Tensor, int) -> (Tensor)")
     emit("aten::permute : (Tensor, int[]) -> (Tensor)", has_verifier=True)
     emit("aten::movedim.int : (Tensor, int, int) -> (Tensor)")
@@ -769,9 +769,11 @@ def emit_with_mutating_variants(key, **kwargs):
     emit("aten::pad : (Tensor, int[], str, float?) -> (Tensor)")
     emit("aten::squeeze.dim : (Tensor, int) -> (Tensor)", has_folder=True)
     emit("aten::squeeze : (Tensor) -> (Tensor)", has_folder=True)
-    emit("aten::flatten.using_ints : (Tensor, int, int) -> (Tensor)")
+    emit("aten::flatten.using_ints : (Tensor, int, int) -> (Tensor)", has_folder=True)
     emit(
-        "aten::unflatten.int : (Tensor, int, int[]) -> (Tensor)", has_canonicalizer=True
+        "aten::unflatten.int : (Tensor, int, int[]) -> (Tensor)",
+        has_canonicalizer=True,
+        has_folder=True,
     )
     emit("aten::dim : (Tensor) -> (int)", has_folder=True)
     emit("aten::size : (Tensor) -> (int[])", has_canonicalizer=True)
```

test/Dialect/Torch/canonicalize.mlir

Lines changed: 76 additions & 0 deletions
```diff
@@ -1682,6 +1682,82 @@ func.func @torch.aten.view$1D(%arg0: !torch.tensor<[?],f32>) -> !torch.tensor<[?
   return %1 : !torch.tensor<[?],f32>
 }
 
+// CHECK-LABEL: func.func @torch.aten.view$fold_splat(
+// CHECK: %[[SPLAT:.*]] = torch.vtensor.literal(dense<2> : tensor<2x4x1xsi64>) : !torch.vtensor<[2,4,1],si64>
+// CHECK: return %[[SPLAT]] : !torch.vtensor<[2,4,1],si64>
+func.func @torch.aten.view$fold_splat() -> !torch.vtensor<[2,4,1],si64> {
+  %int4 = torch.constant.int 4
+  %int2 = torch.constant.int 2
+  %int1 = torch.constant.int 1
+  %0 = torch.vtensor.literal(dense<2> : tensor<8xsi64>) : !torch.vtensor<[8],si64>
+  %1 = torch.prim.ListConstruct %int2, %int4, %int1 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
+  %2 = torch.aten.view %0, %1 : !torch.vtensor<[8],si64>, !torch.list<int> -> !torch.vtensor<[2,4,1],si64>
+  return %2 : !torch.vtensor<[2,4,1],si64>
+}
+
+// CHECK-LABEL: func.func @torch.aten.view$fold_literal(
+// CHECK: %[[LITERAL:.*]] = torch.vtensor.literal(dense<[
+// CHECK-SAME: [
+// CHECK-SAME: [0, 1], [2, 3], [4, 5], [6, 7]]]> : tensor<1x4x2xsi64>) : !torch.vtensor<[1,4,2],si64>
+// CHECK: return %[[LITERAL]] : !torch.vtensor<[1,4,2],si64>
+func.func @torch.aten.view$fold_literal() -> !torch.vtensor<[1,4,2],si64> {
+  %int4 = torch.constant.int 4
+  %int2 = torch.constant.int 2
+  %int1 = torch.constant.int 1
+  %0 = torch.vtensor.literal(dense<[0,1,2,3,4,5,6,7]> : tensor<8xsi64>) : !torch.vtensor<[8],si64>
+  %1 = torch.prim.ListConstruct %int1, %int4, %int2 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
+  %2 = torch.aten.view %0, %1 : !torch.vtensor<[8],si64>, !torch.list<int> -> !torch.vtensor<[1,4,2],si64>
+  return %2 : !torch.vtensor<[1,4,2],si64>
+}
+
+// CHECK-LABEL: func.func @torch.aten.transpose.int$fold_literal(
+// CHECK: %[[LIT:.*]] = torch.vtensor.literal(dense<[
+// CHECK-SAME: [0, 2, 4, 6], [1, 3, 5, 7]]> : tensor<2x4xsi64>) : !torch.vtensor<[2,4],si64>
+// CHECK: return %[[LIT]] : !torch.vtensor<[2,4],si64>
+func.func @torch.aten.transpose.int$fold_literal() -> !torch.vtensor<[2,4],si64> {
+  %int-1 = torch.constant.int -1
+  %int0 = torch.constant.int 0
+  %0 = torch.vtensor.literal(dense<[[0,1],[2,3],[4,5],[6,7]]> : tensor<4x2xsi64>) : !torch.vtensor<[4,2],si64>
+  %1 = torch.aten.transpose.int %0, %int-1, %int0 : !torch.vtensor<[4,2],si64>, !torch.int, !torch.int -> !torch.vtensor<[2,4], si64>
+  return %1 : !torch.vtensor<[2,4],si64>
+}
+
+// CHECK-LABEL: func.func @torch.aten.transpose.int$fold_noop(
+// CHECK: return %arg0 : !torch.vtensor<[?,?,?,?],f32>
+func.func @torch.aten.transpose.int$fold_noop(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?,?,?,?],f32> {
+  %int-1 = torch.constant.int -1
+  %int3 = torch.constant.int 3
+  %0 = torch.aten.transpose.int %arg0, %int-1, %int3 : !torch.vtensor<[?,?,?,?],f32>, !torch.int, !torch.int -> !torch.vtensor<[?,?,?,?],f32>
+  return %0 : !torch.vtensor<[?,?,?,?],f32>
+}
+
+// CHECK-LABEL: func.func @torch.aten.slice.Tensor$flip_slice_fold(
+// CHECK: %[[LIT:.*]] = torch.vtensor.literal(dense<[
+// CHECK-SAME: [6, 7], [4, 5], [2, 3], [0, 1]]> : tensor<4x2xsi64>) : !torch.vtensor<[4,2],si64>
+// CHECK: return %[[LIT]] : !torch.vtensor<[4,2],si64>
+func.func @torch.aten.slice.Tensor$flip_slice_fold() -> !torch.vtensor<[4,2],si64> {
+  %int-9223372036854775807 = torch.constant.int -9223372036854775807
+  %int-1 = torch.constant.int -1
+  %int0 = torch.constant.int 0
+  %0 = torch.vtensor.literal(dense<[[0, 1], [2, 3], [4, 5], [6, 7]]> : tensor<4x2xsi64>) : !torch.vtensor<[4,2],si64>
+  %1 = torch.aten.slice.Tensor %0, %int0, %int-1, %int-9223372036854775807, %int-1 : !torch.vtensor<[4,2],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,2],si64>
+  return %1 : !torch.vtensor<[4,2],si64>
+}
+
+// CHECK-LABEL: func.func @torch.aten.slice.Tensor$negative_two_stride_fold(
+// CHECK: %[[LIT:.*]] = torch.vtensor.literal(dense<[
+// CHECK-SAME: [6, 7], [2, 3]]> : tensor<2x2xsi64>) : !torch.vtensor<[2,2],si64>
+// CHECK: return %[[LIT]] : !torch.vtensor<[2,2],si64>
+func.func @torch.aten.slice.Tensor$negative_two_stride_fold() -> !torch.vtensor<[2,2],si64> {
+  %int-5 = torch.constant.int -5
+  %int-1 = torch.constant.int -1
+  %int-2 = torch.constant.int -2
+  %int0 = torch.constant.int 0
+  %0 = torch.vtensor.literal(dense<[[0, 1], [2, 3], [4, 5], [6, 7]]> : tensor<4x2xsi64>) : !torch.vtensor<[4,2],si64>
+  %1 = torch.aten.slice.Tensor %0, %int0, %int-1, %int-5, %int-2 : !torch.vtensor<[4,2],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[2,2],si64>
+  return %1 : !torch.vtensor<[2,2],si64>
+}
+
 // CHECK-LABEL: func.func @torch.aten.div.float$fold_zero_dividend(
 // CHECK: %[[CST0:.*]] = torch.constant.float 0.000000e+00
 // CHECK: return %[[CST0]] : !torch.float
```
