Skip to content

Commit b14d379

Browse files
GleasonK authored and Google-ML-Automation committed
[StableHLO] Add transpose simplification
PiperOrigin-RevId: 820804015
1 parent 5b41cc2 commit b14d379

File tree

1 file changed

+62
-0
lines changed

1 file changed

+62
-0
lines changed

third_party/stablehlo/temporary.patch

Lines changed: 62 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -863,6 +863,25 @@ diff --ruN a/stablehlo/stablehlo/tests/transforms/stablehlo_aggressive_folder.ml
863863
}
864864

865865
// -----
866+
diff --ruN a/stablehlo/stablehlo/tests/transforms/stablehlo_aggressive_simplification.mlir b/stablehlo/stablehlo/tests/transforms/stablehlo_aggressive_simplification.mlir
867+
--- stablehlo/stablehlo/tests/transforms/stablehlo_aggressive_simplification.mlir
868+
+++ stablehlo/stablehlo/tests/transforms/stablehlo_aggressive_simplification.mlir
869+
@@ -1810,6 +1810,15 @@
870+
return %0 : tensor<2x4x1x5xf32>
871+
}
872+
873+
+// CHECK-LABEL: @transpose_of_transpose
874+
+func.func @transpose_of_transpose(%arg0 : tensor<1x2x3x4xf32>) -> tensor<1x2x3x4xf32> {
875+
+ %0 = stablehlo.transpose %arg0, dims = [3,2,1,0] : (tensor<1x2x3x4xf32>) -> tensor<4x3x2x1xf32>  // reverses the dimension order
876+
+ %1 = stablehlo.transpose %0, dims = [3,2,1,0] : (tensor<4x3x2x1xf32>) -> tensor<1x2x3x4xf32>  // reversing again restores the original order, so the pair composes to the identity
877+
+ // CHECK-NOT: stablehlo.transpose
878+
+ // CHECK: return %arg0
879+
+ return %1 : tensor<1x2x3x4xf32>
880+
+}
881+
+
882+
// -----
883+
884+
////////
866885
diff --ruN a/stablehlo/stablehlo/transforms/ChloLegalizeToStablehlo.cpp b/stablehlo/stablehlo/transforms/ChloLegalizeToStablehlo.cpp
867886
--- stablehlo/stablehlo/transforms/ChloLegalizeToStablehlo.cpp
868887
+++ stablehlo/stablehlo/transforms/ChloLegalizeToStablehlo.cpp
@@ -1541,4 +1560,47 @@ diff --ruN a/stablehlo/stablehlo/transforms/optimization/StablehloAggressiveSimp
15411560
permutation[dims[i]] = i;
15421561
}
15431562
return b.getDenseI64ArrayAttr(permutation);
1563+
@@ -1308,6 +1308,17 @@
1564+
//////////////////////////////////
1565+
// TransposeOp
1566+
/////////////////////////////////
1567+
+// Composes two transpose permutations: mergedPerm[i] = childPerm[parentPerm[i]].
1568+
+DenseI64ArrayAttr getMergedTransposePermutation(OpBuilder& b,
1569+
+ ArrayRef<int64_t> childPerm,  // permutation of the inner (operand) transpose
1570+
+ ArrayRef<int64_t> parentPerm) {  // permutation of the outer transpose
1571+
+ SmallVector<int64_t> mergedPerm;
1572+
+ mergedPerm.reserve(parentPerm.size());  // result has one entry per parentPerm entry
1573+
+ for (int64_t parentIdx : parentPerm) {
1574+
+ mergedPerm.push_back(childPerm[parentIdx]);  // no bounds check here: assumes each parentPerm entry is a valid index into childPerm
1575+
+ }
1576+
+ return b.getDenseI64ArrayAttr(mergedPerm);  // attribute is created through the supplied builder
1577+
+}
1578+
1579+
// Pattern: transpose(X, [no_mem_layout_change...]) -> reshape(X)
1580+
struct TransposeIsReshape final : SimplifyOpRewritePattern<TransposeOp> {
1581+
diff --ruN a/stablehlo/stablehlo/transforms/optimization/StablehloAggressiveSimplificationPatterns.td b/stablehlo/stablehlo/transforms/optimization/StablehloAggressiveSimplificationPatterns.td
1582+
--- stablehlo/stablehlo/transforms/optimization/StablehloAggressiveSimplificationPatterns.td
1583+
+++ stablehlo/stablehlo/transforms/optimization/StablehloAggressiveSimplificationPatterns.td
1584+
@@ -119,6 +119,8 @@
1585+
def InvertBroadcastDims : NativeCodeCall<"getInvertedBroadcastDimensions($_builder, $0)">;
1586+
1587+
def MergeBroadcastDims : NativeCodeCall<"getMergedBroadcastDimensions($_builder, $0, $1)">;
1588+
+// Forwards $0 and $1 to getMergedTransposePermutation to build the combined permutation attribute.
1589+
+def MergePermutations : NativeCodeCall<"getMergedTransposePermutation($_builder, $0, $1)">;
1590+
1591+
def StableHLO_ConvertOpWithShape : NativeCodeCall<
1592+
"$_builder.create<stablehlo::ConvertOp>($_loc, $0.getType(), $1)">;
1593+
@@ -539,6 +541,12 @@
1594+
: Pat<(StableHLO_TransposeOp $lhs, IotaDims:$dims),
1595+
(replaceWithValue $lhs)>;
1596+
1597+
+// Pattern: transpose(transpose(X, child_dims), dims) -> transpose(X, MergePermutations(child_dims, dims))
1598+
+def TransposeOp_TransposeOfTranspose
1599+
+ : Pat<(StableHLO_TransposeOp
1600+
+ (StableHLO_TransposeOp $child, $child_dims), $dims),
1601+
+ (StableHLO_TransposeOp $child, (MergePermutations $child_dims, $dims))>;
1602+
+
1603+
////////
1604+
// GetTupleElementOp
1605+
15441606

0 commit comments

Comments
 (0)