@@ -863,6 +863,25 @@ diff --ruN a/stablehlo/stablehlo/tests/transforms/stablehlo_aggressive_folder.ml
863863 }
864864
865865 // -----
866+ diff --ruN a/stablehlo/stablehlo/tests/transforms/stablehlo_aggressive_simplification.mlir b/stablehlo/stablehlo/tests/transforms/stablehlo_aggressive_simplification.mlir
867+ --- stablehlo/stablehlo/tests/transforms/stablehlo_aggressive_simplification.mlir
868+ +++ stablehlo/stablehlo/tests/transforms/stablehlo_aggressive_simplification.mlir
869+ @@ -1810,6 +1810,15 @@
870+ return %0 : tensor<2x4x1x5xf32>
871+ }
872+
873+ + // CHECK-LABEL: @transpose_of_transpose
874+ + func.func @transpose_of_transpose(%arg0 : tensor<1x2x3x4xf32>) -> tensor<1x2x3x4xf32> {
875+ + %0 = stablehlo.transpose %arg0, dims = [3,2,1,0] : (tensor<1x2x3x4xf32>) -> tensor<4x3x2x1xf32>
876+ + %1 = stablehlo.transpose %0, dims = [3,2,1,0] : (tensor<4x3x2x1xf32>) -> tensor<1x2x3x4xf32> // the same reversal applied twice composes to the identity permutation, so both ops are expected to fold away
877+ + // CHECK-NOT: stablehlo.transpose
878+ + // CHECK: return %arg0
879+ + return %1 : tensor<1x2x3x4xf32>
880+ + }
881+ +
882+ // -----
883+
884+ ////////
866885diff --ruN a/stablehlo/stablehlo/transforms/ChloLegalizeToStablehlo.cpp b/stablehlo/stablehlo/transforms/ChloLegalizeToStablehlo.cpp
867886--- stablehlo/stablehlo/transforms/ChloLegalizeToStablehlo.cpp
868887+++ stablehlo/stablehlo/transforms/ChloLegalizeToStablehlo.cpp
@@ -1541,4 +1560,47 @@ diff --ruN a/stablehlo/stablehlo/transforms/optimization/StablehloAggressiveSimp
15411560 permutation[dims[i]] = i;
15421561 }
15431562 return b.getDenseI64ArrayAttr(permutation);
1563+ @@ -1308,6 +1308,17 @@
1564+ //////////////////////////////////
1565+ // TransposeOp
1566+ /////////////////////////////////
1567+ +
1568+ + DenseI64ArrayAttr getMergedTransposePermutation(OpBuilder& b,  // Composes two transpose permutations into one; `b` only builds the result attribute.
1569+ + ArrayRef<int64_t> childPerm,  // permutation of the inner (first-applied) transpose
1570+ + ArrayRef<int64_t> parentPerm) {  // permutation of the outer transpose that consumes the inner result
1571+ + SmallVector<int64_t> mergedPerm;
1572+ + mergedPerm.reserve(parentPerm.size());  // result has exactly one entry per parentPerm entry
1573+ + for (int64_t parentIdx : parentPerm) {
1574+ + mergedPerm.push_back(childPerm[parentIdx]);  // mergedPerm[i] = childPerm[parentPerm[i]]; assumes every parentPerm entry is a valid index into childPerm
1575+ + }
1576+ + return b.getDenseI64ArrayAttr(mergedPerm);
1577+ + }
1578+
1579+ // Pattern: transpose(X, [no_mem_layout_change...]) -> reshape(X)
1580+ struct TransposeIsReshape final : SimplifyOpRewritePattern<TransposeOp> {
1581+ diff --ruN a/stablehlo/stablehlo/transforms/optimization/StablehloAggressiveSimplificationPatterns.td b/stablehlo/stablehlo/transforms/optimization/StablehloAggressiveSimplificationPatterns.td
1582+ --- stablehlo/stablehlo/transforms/optimization/StablehloAggressiveSimplificationPatterns.td
1583+ +++ stablehlo/stablehlo/transforms/optimization/StablehloAggressiveSimplificationPatterns.td
1584+ @@ -119,6 +119,8 @@
1585+ def InvertBroadcastDims : NativeCodeCall<"getInvertedBroadcastDimensions($_builder, $0)">;
1586+
1587+ def MergeBroadcastDims : NativeCodeCall<"getMergedBroadcastDimensions($_builder, $0, $1)">;
1588+ +
1589+ + def MergePermutations : NativeCodeCall<"getMergedTransposePermutation($_builder, $0, $1)">;
1590+
1591+ def StableHLO_ConvertOpWithShape : NativeCodeCall<
1592+ "$_builder.create<stablehlo::ConvertOp>($_loc, $0.getType(), $1)">;
1593+ @@ -539,6 +541,12 @@
1594+ : Pat<(StableHLO_TransposeOp $lhs, IotaDims:$dims),
1595+ (replaceWithValue $lhs)>;
1596+
1597+ + // Pattern: transpose(transpose(X, child_dims), dims) -> transpose(X, MergePermutations(child_dims, dims))
1598+ + def TransposeOp_TransposeOfTranspose
1599+ + : Pat<(StableHLO_TransposeOp
1600+ + (StableHLO_TransposeOp $child, $child_dims), $dims),
1601+ + (StableHLO_TransposeOp $child, (MergePermutations $child_dims, $dims))>;
1602+ +
1603+ ////////
1604+ // GetTupleElementOp
1605+
15441606
0 commit comments