Skip to content

Commit e5ca0d8

Browse files
committed
catch additional foldable case
1 parent 6ed05ed commit e5ca0d8

File tree

3 files changed

+49
-38
lines changed

3 files changed

+49
-38
lines changed

mlir/lib/Dialect/Vector/IR/VectorOps.cpp

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -5573,13 +5573,11 @@ LogicalResult ShapeCastOp::verify() {
55735573
return success();
55745574
}
55755575

5576-
namespace {
5577-
55785576
/// Return true if `transpose` does not permute a pair of non-unit dims.
55795577
/// By `order preserving` we mean that the flattened versions of the input and
55805578
/// output vectors are (numerically) identical. In other words `transpose` is
55815579
/// effectively a shape cast.
5582-
bool isOrderPreserving(TransposeOp transpose) {
5580+
static bool isOrderPreserving(TransposeOp transpose) {
55835581
ArrayRef<int64_t> permutation = transpose.getPermutation();
55845582
VectorType sourceType = transpose.getSourceVectorType();
55855583
ArrayRef<int64_t> inShape = sourceType.getShape();
@@ -5599,8 +5597,6 @@ bool isOrderPreserving(TransposeOp transpose) {
55995597
return true;
56005598
}
56015599

5602-
} // namespace
5603-
56045600
OpFoldResult ShapeCastOp::fold(FoldAdaptor adaptor) {
56055601

56065602
VectorType resultType = getType();
@@ -5997,18 +5993,22 @@ OpFoldResult vector::TransposeOp::fold(FoldAdaptor adaptor) {
59975993
if (llvm::dyn_cast_if_present<ub::PoisonAttr>(adaptor.getVector()))
59985994
return ub::PoisonAttr::get(getContext());
59995995

6000-
// Eliminate identity transpose ops. This happens when the dimensions of the
6001-
// input vector remain in their original order after the transpose operation.
6002-
ArrayRef<int64_t> perm = getPermutation();
6003-
6004-
// Check if the permutation of the dimensions contains sequential values:
6005-
// {0, 1, 2, ...}.
6006-
for (int64_t i = 0, e = perm.size(); i < e; i++) {
6007-
if (perm[i] != i)
6008-
return {};
6009-
}
5996+
// Eliminate identity transposes, and more generally any transposes that
5997+
// preserves the shape without permuting elements.
5998+
//
5999+
// Examples of what to fold:
6000+
// %0 = vector.transpose %arg, [0, 1] : vector<1x1xi8> to vector<1x1xi8>
6001+
// %0 = vector.transpose %arg, [0, 1] : vector<2x2xi8> to vector<2x2xi8>
6002+
// %0 = vector.transpose %arg, [1, 0] : vector<1x1xi8> to vector<1x1xi8>
6003+
//
6004+
// Example of what NOT to fold:
6005+
// %0 = vector.transpose %arg, [1, 0] : vector<2x2xi8> to vector<2x2xi8>
6006+
//
6007+
if (getSourceVectorType() == getResultVectorType() &&
6008+
isOrderPreserving(*this))
6009+
return getVector();
60106010

6011-
return getVector();
6011+
return {};
60126012
}
60136013

60146014
LogicalResult vector::TransposeOp::verify() {

mlir/test/Dialect/Vector/canonicalize.mlir

Lines changed: 0 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -450,28 +450,6 @@ func.func @extract_strided_fold_insert(%a: vector<2x8xf32>, %b: vector<1x4xf32>,
450450

451451
// -----
452452

453-
// CHECK-LABEL: transpose_1D_identity
454-
// CHECK-SAME: ([[ARG:%.*]]: vector<4xf32>)
455-
func.func @transpose_1D_identity(%arg : vector<4xf32>) -> vector<4xf32> {
456-
// CHECK-NOT: transpose
457-
%0 = vector.transpose %arg, [0] : vector<4xf32> to vector<4xf32>
458-
// CHECK-NEXT: return [[ARG]]
459-
return %0 : vector<4xf32>
460-
}
461-
462-
// -----
463-
464-
// CHECK-LABEL: transpose_2D_identity
465-
// CHECK-SAME: ([[ARG:%.*]]: vector<4x3xf32>)
466-
func.func @transpose_2D_identity(%arg : vector<4x3xf32>) -> vector<4x3xf32> {
467-
// CHECK-NOT: transpose
468-
%0 = vector.transpose %arg, [0, 1] : vector<4x3xf32> to vector<4x3xf32>
469-
// CHECK-NEXT: return [[ARG]]
470-
return %0 : vector<4x3xf32>
471-
}
472-
473-
// -----
474-
475453
// CHECK-LABEL: transpose_3D_identity
476454
// CHECK-SAME: ([[ARG:%.*]]: vector<4x3x2xf32>)
477455
func.func @transpose_3D_identity(%arg : vector<4x3x2xf32>) -> vector<4x3x2xf32> {

mlir/test/Dialect/Vector/canonicalize/vector-transpose.mlir

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -248,3 +248,36 @@ func.func @negative_transpose_of_shape_cast(%arg : vector<6xi8>) -> vector<2x3xi
248248
%1 = vector.transpose %0, [1, 0] : vector<3x2xi8> to vector<2x3xi8>
249249
return %1 : vector<2x3xi8>
250250
}
251+
252+
// -----
253+
254+
// Test of transpose folding
255+
// CHECK-LABEL: transpose_1D_identity
256+
// CHECK-SAME: [[ARG:%.*]]: vector<4xf32>
257+
// CHECK-NEXT: return [[ARG]]
258+
func.func @transpose_1D_identity(%arg : vector<4xf32>) -> vector<4xf32> {
259+
%0 = vector.transpose %arg, [0] : vector<4xf32> to vector<4xf32>
260+
return %0 : vector<4xf32>
261+
}
262+
263+
// -----
264+
265+
// Test of transpose folding
266+
// CHECK-LABEL: transpose_2D_identity
267+
// CHECK-SAME: [[ARG:%.*]]: vector<4x3xf32>
268+
// CHECK-NEXT: return [[ARG]]
269+
func.func @transpose_2D_identity(%arg : vector<4x3xf32>) -> vector<4x3xf32> {
270+
%0 = vector.transpose %arg, [0, 1] : vector<4x3xf32> to vector<4x3xf32>
271+
return %0 : vector<4x3xf32>
272+
}
273+
274+
// -----
275+
276+
// Test of transpose folding
277+
// CHECK-LABEL: transpose_shape_and_order_preserving
278+
// CHECK-SAME: [[ARG:%.*]]: vector<6x1x1x4xi8>
279+
// CHECK-NEXT: return [[ARG]]
280+
func.func @transpose_shape_and_order_preserving(%arg : vector<6x1x1x4xi8>) -> vector<6x1x1x4xi8> {
281+
%0 = vector.transpose %arg, [0, 2, 1, 3] : vector<6x1x1x4xi8> to vector<6x1x1x4xi8>
282+
return %0 : vector<6x1x1x4xi8>
283+
}

0 commit comments

Comments
 (0)