[mlir][linalg] Add transpose support for reshape as consumer fusion #130344
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-linalg

Author: Nirvedh Meshram (nirvedhmeshram)

Changes: In #129128, handling of linalg.transpose for reshape-as-consumer fusion was missed. This PR adds it.

Full diff: https://github.com/llvm/llvm-project/pull/130344.diff

2 Files Affected:
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
index a45b5c43f5d33..a35afc60d6cb0 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
@@ -816,17 +816,51 @@ validateDynamicDimExpansion(LinalgOp linalgOp,
}
// Create an expanded transpose op.
+// For bubbling a collapse: transpose(collapse_shape),
+// all expanded groups are permuted together. We just permute the reassociation
+// map of the collapse and flatten it. For example,
+//
+// reassociation_map = [[0], [1, 2, 3], [4, 5]]
+// permutation = [2, 0, 1]
+//
+// Becomes
+//
+// permutation = [4, 5, 0, 1, 2, 3]
+//
+// For sinking an expand: expand_shape(transpose),
+// the reassociation map is already permuted, hence we apply the inverse
+// permutation and then flatten it. Then we invert the flattened permutation
+// again to get the final expanded transpose permutation. For example,
+//
+// permutation = [2, 0, 1]
+// reassociation_map = [[0, 1], [2], [3, 4, 5]]
+//
+// inverse permutation = [1, 2, 0]
+// applied to reassociation_map and then flattened becomes
+// flattened permutation = [2, 3, 4, 5, 0, 1]
+// The final permutation is the inverse of the flattened permutation.
+//
+// Becomes
+//
+// permutation = [4, 5, 0, 1, 2, 3]
+
static Operation *
createExpandedTransposeOp(PatternRewriter &rewriter, TransposeOp transposeOp,
SmallVector<ReassociationIndices> reassociation,
- Value expandedInput, Value output) {
- applyPermutationToVector(reassociation, transposeOp.getPermutation());
+ Value expandedInput, Value output, bool isExpanding) {
+ ArrayRef<int64_t> permutation =
+ isExpanding ? invertPermutationVector(transposeOp.getPermutation())
+ : transposeOp.getPermutation();
+ applyPermutationToVector(reassociation, permutation);
SmallVector<int64_t> newPerm;
for (auto reassoc : reassociation) {
for (auto dim : reassoc) {
newPerm.push_back(dim);
}
}
+ if (isExpanding) {
+ newPerm = invertPermutationVector(newPerm);
+ }
return rewriter.create<TransposeOp>(transposeOp.getLoc(), expandedInput,
output, newPerm);
}
@@ -866,12 +900,13 @@ static Operation *createExpandedOp(
PatternRewriter &rewriter, LinalgOp linalgOp, TypeRange resultTypes,
ArrayRef<Value> expandedOpOperands, ArrayRef<Value> outputs,
ArrayRef<AffineMap> expandedOpIndexingMaps, ExpansionInfo &expansionInfo,
- SmallVector<ReassociationIndices> reassociation) {
+ SmallVector<ReassociationIndices> reassociation, bool isExpanding) {
return TypeSwitch<Operation *, Operation *>(linalgOp.getOperation())
.Case<TransposeOp>([&](TransposeOp transposeOp) {
return createExpandedTransposeOp(rewriter, transposeOp, reassociation,
- expandedOpOperands[0], outputs[0]);
+ expandedOpOperands[0], outputs[0],
+ isExpanding);
})
.Case<FillOp, CopyOp>([&](Operation *op) {
return clone(rewriter, linalgOp, resultTypes,
@@ -994,9 +1029,10 @@ fuseWithReshapeByExpansion(LinalgOp linalgOp, Operation *reshapeOp,
SmallVector<ReassociationIndices> reassociationBeforeExpansion =
isExpanding ? expandingReshapeOp.getReassociationIndices()
: collapsingReshapeOp.getReassociationIndices();
- Operation *fusedOp = createExpandedOp(
- rewriter, linalgOp, resultTypes, expandedOpOperands, outputs,
- expandedOpIndexingMaps, expansionInfo, reassociationBeforeExpansion);
+ Operation *fusedOp =
+ createExpandedOp(rewriter, linalgOp, resultTypes, expandedOpOperands,
+ outputs, expandedOpIndexingMaps, expansionInfo,
+ reassociationBeforeExpansion, isExpanding);
// Reshape the result values to their original shape if this is a collapsing
// reshape folded into its consumer.
SmallVector<Value> resultVals;
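To make the permutation logic documented in the new comment concrete, here is a minimal standalone sketch that reproduces both worked examples. It is illustrative only, not the MLIR implementation: plain std::vector stands in for SmallVector/ArrayRef, and the helpers permuteGroups, invert, and expandedPermutation are hypothetical names mirroring what applyPermutationToVector and invertPermutationVector do in the patch.

// Standalone sketch of the permutation computation in
// createExpandedTransposeOp; helper names are illustrative.
#include <cassert>
#include <cstdint>
#include <vector>

using Perm = std::vector<std::int64_t>;
using Reassociation = std::vector<std::vector<std::int64_t>>;

// out[i] = in[perm[i]], mirroring applyPermutationToVector.
static Reassociation permuteGroups(const Reassociation &in, const Perm &perm) {
  Reassociation out(in.size());
  for (std::size_t i = 0; i < perm.size(); ++i)
    out[i] = in[perm[i]];
  return out;
}

// inv[perm[i]] = i, mirroring invertPermutationVector.
static Perm invert(const Perm &perm) {
  Perm inv(perm.size());
  for (std::size_t i = 0; i < perm.size(); ++i)
    inv[perm[i]] = static_cast<std::int64_t>(i);
  return inv;
}

static Perm expandedPermutation(Reassociation reassociation, Perm perm,
                                bool isExpanding) {
  // Sinking an expand: the reassociation is already permuted, so start
  // from the inverse permutation.
  if (isExpanding)
    perm = invert(perm);
  // Permute the expanded groups together, then flatten them.
  Perm flat;
  for (const auto &group : permuteGroups(reassociation, perm))
    for (std::int64_t dim : group)
      flat.push_back(dim);
  // Sinking an expand: invert once more to get the final permutation.
  return isExpanding ? invert(flat) : flat;
}

int main() {
  // Bubbling a collapse: transpose(collapse_shape).
  assert((expandedPermutation({{0}, {1, 2, 3}, {4, 5}}, {2, 0, 1},
                              /*isExpanding=*/false) ==
          Perm{4, 5, 0, 1, 2, 3}));
  // Sinking an expand: expand_shape(transpose).
  assert((expandedPermutation({{0, 1}, {2}, {3, 4, 5}}, {2, 0, 1},
                              /*isExpanding=*/true) ==
          Perm{4, 5, 0, 1, 2, 3}));
  return 0;
}

Both calls produce [4, 5, 0, 1, 2, 3], the permutation checked in the new tests below.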
diff --git a/mlir/test/Dialect/Linalg/reshape_fusion.mlir b/mlir/test/Dialect/Linalg/reshape_fusion.mlir
index 4da9c0851ac70..7c2b55ca745ff 100644
--- a/mlir/test/Dialect/Linalg/reshape_fusion.mlir
+++ b/mlir/test/Dialect/Linalg/reshape_fusion.mlir
@@ -195,7 +195,7 @@ func.func @generic_op_reshape_consumer_static(%arg0: tensor<264x4xf32>)
// CHECK-SAME: : tensor<8x33x4xf32>
// CHECK-DAG: %[[INIT:.+]] = tensor.empty()
// CHECK: %[[T0:.+]] = tensor.expand_shape %[[ARG0]] {{\[\[}}0, 1], [2]] output_shape [8, 33, 4] : tensor<264x4xf32> into tensor<8x33x4xf32>
-// CHECK: %[[T1:.+]] = tensor.expand_shape %[[VAL_0]] {{\[\[}}0, 1], [2]] output_shape [8, 33, 4] : tensor<264x4xf32> into tensor<8x33x4xf32>
+// CHECK: %[[T1:.+]] = tensor.expand_shape %[[INIT]] {{\[\[}}0, 1], [2]] output_shape [8, 33, 4] : tensor<264x4xf32> into tensor<8x33x4xf32>
// CHECK: %[[T2:.+]] = linalg.generic
// CHECK-SAME: indexing_maps = [#[[MAP2]], #[[MAP2]], #[[MAP2]]]
// CHECK-SAME: ["parallel", "parallel", "parallel"]
@@ -203,6 +203,29 @@ func.func @generic_op_reshape_consumer_static(%arg0: tensor<264x4xf32>)
// CHECK-SAME: outs(%[[T1]] : tensor<8x33x4xf32>)
// CHECK: return %[[T2]] : tensor<8x33x4xf32>
+// -----
+
+func.func @reshape_as_consumer_transpose
+ (%a : tensor<4x210x6xf32>)
+ -> tensor<2x3x4x5x6x7xf32> {
+ %b = tensor.empty() : tensor<6x4x210xf32>
+ %c = linalg.transpose
+ ins(%a : tensor<4x210x6xf32>)
+ outs(%b : tensor<6x4x210xf32>) permutation = [2, 0, 1]
+ %d = tensor.expand_shape %c [[0, 1], [2], [3, 4, 5]] output_shape [2, 3, 4, 5, 6, 7] : tensor<6x4x210xf32> into tensor<2x3x4x5x6x7xf32>
+ return %d : tensor<2x3x4x5x6x7xf32>
+}
+// CHECK: func @reshape_as_consumer_transpose
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<4x210x6xf32>
+// CHECK-DAG: %[[INIT:.+]] = tensor.empty()
+// CHECK-DAG: %[[T0:.+]] = tensor.expand_shape %[[ARG0]] {{\[\[}}0], [1, 2, 3], [4, 5]] output_shape [4, 5, 6, 7, 2, 3] : tensor<4x210x6xf32> into tensor<4x5x6x7x2x3xf32>
+// CHECK-DAG: %[[T1:.+]] = tensor.expand_shape %[[INIT]] {{\[\[}}0, 1], [2], [3, 4, 5]] output_shape [2, 3, 4, 5, 6, 7] : tensor<6x4x210xf32> into tensor<2x3x4x5x6x7xf32>
+// CHECK: %[[T2:.+]] = linalg.transpose ins(%[[T0]] : tensor<4x5x6x7x2x3xf32>)
+// CHECK-SAME: outs(%[[T1]] : tensor<2x3x4x5x6x7xf32>)
+// CHECK-SAME: permutation = [4, 5, 0, 1, 2, 3]
+// CHECK: return %[[T2]] : tensor<2x3x4x5x6x7xf32>
+
+
// -----
#map0 = affine_map<(d0, d1, d2) -> (d2, d0, d1)>
@@ -859,37 +882,29 @@ func.func @linalg_copy_reshape_producer_fusion(%arg0 : tensor<?x7x?x8xf32>,
// -----
-func.func @linalg_transpose_reshape_producer_fusion(%arg0 : tensor<?x7x?x8xf32>,
- %arg1 : tensor<?x?xf32>) ->
- tensor<?x?xf32>
-{
- %0 = tensor.collapse_shape %arg0 [[0, 1], [2, 3]] :
- tensor<?x7x?x8xf32> into tensor<?x?xf32>
- %1 = linalg.transpose ins(%0 : tensor<?x?xf32>)
- outs(%arg1 : tensor<?x?xf32>) permutation = [1, 0]
- return %1 : tensor<?x?xf32>
+
+func.func @reshape_as_producer_transpose
+ (%a : tensor<4x5x6x7x2x3xf32>)
+ -> tensor<6x4x210xf32> {
+ %b = tensor.empty() : tensor<6x4x210xf32>
+ %c = tensor.collapse_shape %a [[0], [1, 2, 3], [4, 5]] :
+ tensor<4x5x6x7x2x3xf32> into tensor<4x210x6xf32>
+ %d = linalg.transpose
+ ins(%c : tensor<4x210x6xf32>)
+ outs(%b : tensor<6x4x210xf32>) permutation = [2, 0, 1]
+ return %d : tensor<6x4x210xf32>
}
-// CHECK: func @linalg_transpose_reshape_producer_fusion
-// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x7x?x8xf32>
-// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
-// CHECK-DAG: %[[C8:.+]] = arith.constant 8 : index
-// CHECK-DAG: %[[C7:.+]] = arith.constant 7 : index
-// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
-// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
-// CHECK-DAG: %[[DIM:.+]] = tensor.dim %[[ARG1]], %[[C0]] : tensor<?x?xf32>
-// CHECK-DAG: %[[DIM_0:.+]] = tensor.dim %[[ARG1]], %[[C1]] : tensor<?x?xf32>
-// CHECK-DAG: %[[VAL_0:.+]] = arith.divsi %[[DIM_0]], %[[C7]] : index
-// CHECK-DAG: %[[VAL_1:.+]] = arith.divsi %[[DIM]], %[[C8]] : index
-// CHECK: %[[T1:.+]] = tensor.expand_shape %[[ARG1]] {{\[\[}}0, 1], [2, 3]] output_shape [%[[VAL_1]], 8, %[[VAL_0]], 7] : tensor<?x?xf32> into tensor<?x8x?x7xf32>
-// CHECK: %[[T2:.+]] = linalg.transpose
-// CHECK-SAME: ins(%[[ARG0]] : tensor<?x7x?x8xf32>)
-// CHECK-SAME: outs(%[[T1]] : tensor<?x8x?x7xf32>)
-// CHECK-SAME: permutation = [2, 3, 0, 1]
-// CHECK: %[[T3:.+]] = tensor.collapse_shape %[[T2]]
-// CHECK-SAME: [0, 1], [2, 3]
-// CHECK-SAME: tensor<?x8x?x7xf32> into tensor<?x?xf32>
-// CHECK: return %[[T3]]
+// CHECK: func @reshape_as_producer_transpose
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<4x5x6x7x2x3xf32>
+// CHECK-DAG: %[[INIT:.+]] = tensor.empty()
+// CHECK-DAG: %[[T0:.+]] = tensor.expand_shape %[[INIT]] {{\[\[}}0, 1], [2], [3, 4, 5]] output_shape [2, 3, 4, 5, 6, 7] : tensor<6x4x210xf32> into tensor<2x3x4x5x6x7xf32>
+// CHECK: %[[T1:.+]] = linalg.transpose ins(%[[ARG0]] : tensor<4x5x6x7x2x3xf32>)
+// CHECK-SAME: outs(%[[T0]] : tensor<2x3x4x5x6x7xf32>)
+// CHECK-SAME: permutation = [4, 5, 0, 1, 2, 3]
+// CHECK: %[[T2:.+]] = tensor.collapse_shape %[[T1]] {{\[\[}}0, 1], [2], [3, 4, 5]] : tensor<2x3x4x5x6x7xf32> into tensor<6x4x210xf32>
+// CHECK: return %[[T2]] : tensor<6x4x210xf32>
+
// -----
krzysz00 left a comment:
Changes lgtm
Signed-off-by: Nirvedh Meshram <[email protected]>
That makes sense. Thanks.
In #129128, handling of linalg.transpose for reshape-as-consumer fusion was missed. This PR adds it.
Also, the transpose reshape-as-producer fusion test is updated to use static sizes, since a wrong permutation then produces mismatched static shapes that the verifier will reject.