Skip to content

Conversation

@nirvedhmeshram
Copy link
Contributor

@nirvedhmeshram nirvedhmeshram commented Mar 7, 2025

During #129128 adding reshape as consumer fusion handling of linalg.transpose was missed. This PR adds that.
Also transpose reshape as producer fusion test is updated to static sizes as that is more likely to catch any issues with the permutation vector in the verifier if the shapes dont match up.

@llvmbot
Copy link
Member

llvmbot commented Mar 7, 2025

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-linalg

Author: Nirvedh Meshram (nirvedhmeshram)

Changes

During #129128 adding reshape as consumer fusion handling of linalg.transpose was missed. This PR adds that.


Full diff: https://github.com/llvm/llvm-project/pull/130344.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp (+43-7)
  • (modified) mlir/test/Dialect/Linalg/reshape_fusion.mlir (+45-30)
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
index a45b5c43f5d33..a35afc60d6cb0 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
@@ -816,17 +816,51 @@ validateDynamicDimExpansion(LinalgOp linalgOp,
 }
 
 // Create an expanded transpose op.
+// For bubbling a collapse : transpose(collapse_shape),
+// all expanded groups are permuted together. We just permute the reassocation
+// map of the collapse and flatten it. For example,
+//
+// reassociation_map = [[0], [1, 2, 3], [4, 5]]
+// permutation = [2, 0, 1]
+//
+// Becomes
+//
+// permutation = [4, 5, 0 , 1, 2, 3]
+//
+// For sinking expand : expand_shape(transpose),
+// the reassociation map is already permuted hence we inverse permutate and then
+// flatten it. Then we inverse permute it again to get the final expanded
+// transpose permutation. For example,
+//
+// permutation = [2, 0, 1]
+// reassociation_map = [[0, 1], [2], [3, 4, 5]]
+//
+// inverse permutation = [1, 2, 0]
+// applied to reassocation_map and then flattened becomes
+// flatened permutation = [2, 3, 4, 5, 0, 1]
+// final permuation is the inverse of the flattened permutation.
+//
+// Becomes
+//
+// permutation=[4, 5, 0, 1, 2, 3]
+
 static Operation *
 createExpandedTransposeOp(PatternRewriter &rewriter, TransposeOp transposeOp,
                           SmallVector<ReassociationIndices> reassociation,
-                          Value expandedInput, Value output) {
-  applyPermutationToVector(reassociation, transposeOp.getPermutation());
+                          Value expandedInput, Value output, bool isExpanding) {
+  ArrayRef<int64_t> permutation =
+      isExpanding ? invertPermutationVector(transposeOp.getPermutation())
+                  : transposeOp.getPermutation();
+  applyPermutationToVector(reassociation, permutation);
   SmallVector<int64_t> newPerm;
   for (auto reassoc : reassociation) {
     for (auto dim : reassoc) {
       newPerm.push_back(dim);
     }
   }
+  if (isExpanding) {
+    newPerm = invertPermutationVector(newPerm);
+  }
   return rewriter.create<TransposeOp>(transposeOp.getLoc(), expandedInput,
                                       output, newPerm);
 }
@@ -866,12 +900,13 @@ static Operation *createExpandedOp(
     PatternRewriter &rewriter, LinalgOp linalgOp, TypeRange resultTypes,
     ArrayRef<Value> expandedOpOperands, ArrayRef<Value> outputs,
     ArrayRef<AffineMap> expandedOpIndexingMaps, ExpansionInfo &expansionInfo,
-    SmallVector<ReassociationIndices> reassociation) {
+    SmallVector<ReassociationIndices> reassociation, bool isExpanding) {
 
   return TypeSwitch<Operation *, Operation *>(linalgOp.getOperation())
       .Case<TransposeOp>([&](TransposeOp transposeOp) {
         return createExpandedTransposeOp(rewriter, transposeOp, reassociation,
-                                         expandedOpOperands[0], outputs[0]);
+                                         expandedOpOperands[0], outputs[0],
+                                         isExpanding);
       })
       .Case<FillOp, CopyOp>([&](Operation *op) {
         return clone(rewriter, linalgOp, resultTypes,
@@ -994,9 +1029,10 @@ fuseWithReshapeByExpansion(LinalgOp linalgOp, Operation *reshapeOp,
   SmallVector<ReassociationIndices> reassociationBeforeExpansion =
       isExpanding ? expandingReshapeOp.getReassociationIndices()
                   : collapsingReshapeOp.getReassociationIndices();
-  Operation *fusedOp = createExpandedOp(
-      rewriter, linalgOp, resultTypes, expandedOpOperands, outputs,
-      expandedOpIndexingMaps, expansionInfo, reassociationBeforeExpansion);
+  Operation *fusedOp =
+      createExpandedOp(rewriter, linalgOp, resultTypes, expandedOpOperands,
+                       outputs, expandedOpIndexingMaps, expansionInfo,
+                       reassociationBeforeExpansion, isExpanding);
   // Reshape the result values to their original shape if this is a collapsing
   // reshape folded into its consumer.
   SmallVector<Value> resultVals;
diff --git a/mlir/test/Dialect/Linalg/reshape_fusion.mlir b/mlir/test/Dialect/Linalg/reshape_fusion.mlir
index 4da9c0851ac70..7c2b55ca745ff 100644
--- a/mlir/test/Dialect/Linalg/reshape_fusion.mlir
+++ b/mlir/test/Dialect/Linalg/reshape_fusion.mlir
@@ -195,7 +195,7 @@ func.func @generic_op_reshape_consumer_static(%arg0: tensor<264x4xf32>)
 // CHECK-SAME:     : tensor<8x33x4xf32>
 //  CHECK-DAG:   %[[INIT:.+]] = tensor.empty()
 //      CHECK:   %[[T0:.+]] = tensor.expand_shape %[[ARG0]] {{\[\[}}0, 1], [2]] output_shape [8, 33, 4] : tensor<264x4xf32> into tensor<8x33x4xf32>
-//      CHECK:   %[[T1:.+]] = tensor.expand_shape %[[VAL_0]] {{\[\[}}0, 1], [2]] output_shape [8, 33, 4] : tensor<264x4xf32> into tensor<8x33x4xf32>
+//      CHECK:   %[[T1:.+]] = tensor.expand_shape %[[INIT]] {{\[\[}}0, 1], [2]] output_shape [8, 33, 4] : tensor<264x4xf32> into tensor<8x33x4xf32>
 //      CHECK:   %[[T2:.+]] = linalg.generic
 // CHECK-SAME:     indexing_maps = [#[[MAP2]], #[[MAP2]], #[[MAP2]]]
 // CHECK-SAME:     ["parallel", "parallel", "parallel"]
@@ -203,6 +203,29 @@ func.func @generic_op_reshape_consumer_static(%arg0: tensor<264x4xf32>)
 // CHECK-SAME:     outs(%[[T1]] : tensor<8x33x4xf32>)
 //      CHECK:   return %[[T2]] : tensor<8x33x4xf32>
 
+// -----
+
+func.func @reshape_as_consumer_transpose
+  (%a :  tensor<4x210x6xf32>)
+    -> tensor<2x3x4x5x6x7xf32> {
+  %b = tensor.empty() : tensor<6x4x210xf32>
+  %c = linalg.transpose
+          ins(%a : tensor<4x210x6xf32>)
+         outs(%b : tensor<6x4x210xf32>) permutation = [2, 0, 1]
+  %d = tensor.expand_shape %c [[0, 1], [2], [3, 4, 5]] output_shape [2, 3, 4, 5, 6, 7] : tensor<6x4x210xf32> into tensor<2x3x4x5x6x7xf32>
+  return %d : tensor<2x3x4x5x6x7xf32>
+}
+//      CHECK: func @reshape_as_consumer_transpose
+// CHECK-SAME:   %[[ARG0:[a-zA-Z0-9_]+]]: tensor<4x210x6xf32>
+//  CHECK-DAG:   %[[INIT:.+]] = tensor.empty()
+//  CHECK-DAG:   %[[T0:.+]] = tensor.expand_shape %[[ARG0]] {{\[\[}}0], [1, 2, 3], [4, 5]] output_shape [4, 5, 6, 7, 2, 3] : tensor<4x210x6xf32> into tensor<4x5x6x7x2x3xf32>
+//  CHECK-DAG:   %[[T1:.+]] = tensor.expand_shape %[[INIT]] {{\[\[}}0, 1], [2], [3, 4, 5]] output_shape [2, 3, 4, 5, 6, 7] : tensor<6x4x210xf32> into tensor<2x3x4x5x6x7xf32
+//      CHECK:   %[[T2:.+]] = linalg.transpose ins(%[[T0]] : tensor<4x5x6x7x2x3xf32>)
+// CHECK-SAME:     outs(%[[T1]] : tensor<2x3x4x5x6x7xf32>)
+// CHECK-SAME:     permutation = [4, 5, 0, 1, 2, 3]
+//      CHECK:   return %[[T2]] : tensor<2x3x4x5x6x7xf32>
+
+
 // -----
 
 #map0 = affine_map<(d0, d1, d2) -> (d2, d0, d1)>
@@ -859,37 +882,29 @@ func.func @linalg_copy_reshape_producer_fusion(%arg0 : tensor<?x7x?x8xf32>,
 
 // -----
 
-func.func @linalg_transpose_reshape_producer_fusion(%arg0 : tensor<?x7x?x8xf32>,
-                                              %arg1 : tensor<?x?xf32>) ->
-                                              tensor<?x?xf32>
-{
-  %0 = tensor.collapse_shape %arg0 [[0, 1], [2, 3]] :
-    tensor<?x7x?x8xf32> into tensor<?x?xf32>
-  %1 = linalg.transpose ins(%0 : tensor<?x?xf32>)
-       outs(%arg1 : tensor<?x?xf32>) permutation = [1, 0]
-  return %1 : tensor<?x?xf32>
+
+func.func @reshape_as_producer_transpose
+  (%a :  tensor<4x5x6x7x2x3xf32>)
+    -> tensor<6x4x210xf32> {
+  %b = tensor.empty() : tensor<6x4x210xf32>
+  %c = tensor.collapse_shape %a [[0], [1, 2, 3], [4, 5]] :
+    tensor<4x5x6x7x2x3xf32> into tensor<4x210x6xf32>
+  %d = linalg.transpose
+          ins(%c : tensor<4x210x6xf32>)
+         outs(%b : tensor<6x4x210xf32>) permutation = [2, 0, 1]
+  return %d : tensor<6x4x210xf32>
 }
 
-//      CHECK: func @linalg_transpose_reshape_producer_fusion
-// CHECK-SAME:   %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x7x?x8xf32>
-// CHECK-SAME:   %[[ARG1:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
-//  CHECK-DAG:   %[[C8:.+]] = arith.constant 8 : index
-//  CHECK-DAG:   %[[C7:.+]] = arith.constant 7 : index
-//  CHECK-DAG:   %[[C1:.+]] = arith.constant 1 : index
-//  CHECK-DAG:   %[[C0:.+]] = arith.constant 0 : index
-//  CHECK-DAG:   %[[DIM:.+]] = tensor.dim %[[ARG1]], %[[C0]] : tensor<?x?xf32>
-//  CHECK-DAG:   %[[DIM_0:.+]] = tensor.dim %[[ARG1]], %[[C1]] : tensor<?x?xf32>
-//  CHECK-DAG:   %[[VAL_0:.+]] = arith.divsi %[[DIM_0]], %[[C7]] : index
-//  CHECK-DAG:   %[[VAL_1:.+]] = arith.divsi %[[DIM]], %[[C8]] : index
-//      CHECK:   %[[T1:.+]] = tensor.expand_shape %[[ARG1]] {{\[\[}}0, 1], [2, 3]] output_shape [%[[VAL_1]], 8, %[[VAL_0]], 7] : tensor<?x?xf32> into tensor<?x8x?x7xf32>
-//      CHECK:   %[[T2:.+]] = linalg.transpose
-// CHECK-SAME:     ins(%[[ARG0]] : tensor<?x7x?x8xf32>)
-// CHECK-SAME:     outs(%[[T1]] : tensor<?x8x?x7xf32>)
-// CHECK-SAME:   permutation = [2, 3, 0, 1]
-//      CHECK:   %[[T3:.+]] = tensor.collapse_shape %[[T2]]
-// CHECK-SAME:     [0, 1], [2, 3]
-// CHECK-SAME:     tensor<?x8x?x7xf32> into tensor<?x?xf32>
-//      CHECK:   return %[[T3]]
+//      CHECK: func @reshape_as_producer_transpose
+// CHECK-SAME:   %[[ARG0:[a-zA-Z0-9_]+]]: tensor<4x5x6x7x2x3xf32>
+//  CHECK-DAG:   %[[INIT:.+]] = tensor.empty()
+//  CHECK-DAG:   %[[T0:.+]] = tensor.expand_shape %[[INIT]] {{\[\[}}0, 1], [2], [3, 4, 5]] output_shape [2, 3, 4, 5, 6, 7] : tensor<6x4x210xf32> into tensor<2x3x4x5x6x7xf32>
+//      CHECK:   %[[T1:.+]] = linalg.transpose ins(%[[ARG0]] : tensor<4x5x6x7x2x3xf32>)
+// CHECK-SAME:     outs(%[[T0]] : tensor<2x3x4x5x6x7xf32>)
+// CHECK-SAME:     permutation = [4, 5, 0, 1, 2, 3]
+//      CHECK:   %[[T2:.+]] = tensor.collapse_shape %[[T1]] {{\[\[}}0, 1], [2], [3, 4, 5]] : tensor<2x3x4x5x6x7xf32> into tensor<6x4x210xf32>
+//      CHECK:   return %[[T2]] : tensor<6x4x210xf32>
+
 
 // -----
 

Copy link
Contributor

@krzysz00 krzysz00 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Changes lgtm

@MaheshRavishankar
Copy link
Contributor

That makes sense. Thanks.

@nirvedhmeshram nirvedhmeshram merged commit 849abd8 into llvm:main Mar 11, 2025
11 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants