Skip to content

Commit 849abd8

Browse files
[mlir][linalg] Add transpose support for reshape as consumer fusion (#130344)
During #129128 adding reshape as consumer fusion handling of linalg.transpose was missed. This PR adds that. Also transpose reshape as producer fusion test is updated to static sizes as that is more likely to catch any issues with the permutation vector in the verifier if the shapes dont match up. --------- Signed-off-by: Nirvedh Meshram <[email protected]>
1 parent 29f5d5b commit 849abd8

File tree

2 files changed

+81
-51
lines changed

2 files changed

+81
-51
lines changed

mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp

Lines changed: 36 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -811,19 +811,35 @@ validateDynamicDimExpansion(LinalgOp linalgOp,
811811
}
812812

813813
// Create an expanded transpose op.
814-
static Operation *
815-
createExpandedTransposeOp(PatternRewriter &rewriter, TransposeOp transposeOp,
816-
SmallVector<ReassociationIndices> reassociation,
817-
Value expandedInput, Value output) {
818-
applyPermutationToVector(reassociation, transposeOp.getPermutation());
814+
// the reassociation map is already permuted hence we inverse permute and then
815+
// flatten it. Then we inverse permute it again to get the final expanded
816+
// transpose permutation. For example,
817+
//
818+
// permutation = [2, 0, 1]
819+
// reassociation_map for expansion = [[0, 1], [2], [3, 4, 5]]
820+
//
821+
// inverse permutation = [1, 2, 0]
822+
// applied to reassociation_map and then flattened becomes
823+
// flattened permutation = [2, 3, 4, 5, 0, 1]
824+
// final permutation is the inverse of the flattened permutation.
825+
//
826+
// Becomes
827+
//
828+
// permutation=[4, 5, 0, 1, 2, 3]
829+
830+
static Operation *createExpandedTransposeOp(PatternRewriter &rewriter,
831+
TransposeOp transposeOp,
832+
Value expandedInput, Value output,
833+
ExpansionInfo &expansionInfo) {
819834
SmallVector<int64_t> newPerm;
820-
for (const auto &reassoc : reassociation) {
821-
for (auto dim : reassoc) {
835+
for (int64_t perm : invertPermutationVector(transposeOp.getPermutation())) {
836+
auto reassoc = expansionInfo.getExpandedDims(perm);
837+
for (int64_t dim : reassoc) {
822838
newPerm.push_back(dim);
823839
}
824840
}
825841
return rewriter.create<TransposeOp>(transposeOp.getLoc(), expandedInput,
826-
output, newPerm);
842+
output, invertPermutationVector(newPerm));
827843
}
828844

829845
// Create an expanded generic op.
@@ -857,16 +873,18 @@ static Operation *createExpandedGenericOp(
857873
// Create an expanded fused op that retains the name for certain ops
858874
// such as fill, copy and transpose and produce a generic op for
859875
// rest of linalg ops.
860-
static Operation *createExpandedOp(
861-
PatternRewriter &rewriter, LinalgOp linalgOp, TypeRange resultTypes,
862-
ArrayRef<Value> expandedOpOperands, ArrayRef<Value> outputs,
863-
ArrayRef<AffineMap> expandedOpIndexingMaps, ExpansionInfo &expansionInfo,
864-
SmallVector<ReassociationIndices> reassociation) {
876+
static Operation *createExpandedOp(PatternRewriter &rewriter, LinalgOp linalgOp,
877+
TypeRange resultTypes,
878+
ArrayRef<Value> expandedOpOperands,
879+
ArrayRef<Value> outputs,
880+
ArrayRef<AffineMap> expandedOpIndexingMaps,
881+
ExpansionInfo &expansionInfo) {
865882

866883
return TypeSwitch<Operation *, Operation *>(linalgOp.getOperation())
867884
.Case<TransposeOp>([&](TransposeOp transposeOp) {
868-
return createExpandedTransposeOp(rewriter, transposeOp, reassociation,
869-
expandedOpOperands[0], outputs[0]);
885+
return createExpandedTransposeOp(rewriter, transposeOp,
886+
expandedOpOperands[0], outputs[0],
887+
expansionInfo);
870888
})
871889
.Case<FillOp, CopyOp>([&](Operation *op) {
872890
return clone(rewriter, linalgOp, resultTypes,
@@ -986,12 +1004,9 @@ fuseWithReshapeByExpansion(LinalgOp linalgOp, Operation *reshapeOp,
9861004
}
9871005

9881006
TypeRange resultTypes = ValueRange(outputs).getTypes();
989-
SmallVector<ReassociationIndices> reassociationBeforeExpansion =
990-
isExpanding ? expandingReshapeOp.getReassociationIndices()
991-
: collapsingReshapeOp.getReassociationIndices();
992-
Operation *fusedOp = createExpandedOp(
993-
rewriter, linalgOp, resultTypes, expandedOpOperands, outputs,
994-
expandedOpIndexingMaps, expansionInfo, reassociationBeforeExpansion);
1007+
Operation *fusedOp =
1008+
createExpandedOp(rewriter, linalgOp, resultTypes, expandedOpOperands,
1009+
outputs, expandedOpIndexingMaps, expansionInfo);
9951010
// Reshape the result values to their original shape if this is a collapsing
9961011
// reshape folded into its consumer.
9971012
SmallVector<Value> resultVals;

mlir/test/Dialect/Linalg/reshape_fusion.mlir

Lines changed: 45 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -195,14 +195,37 @@ func.func @generic_op_reshape_consumer_static(%arg0: tensor<264x4xf32>)
195195
// CHECK-SAME: : tensor<8x33x4xf32>
196196
// CHECK-DAG: %[[INIT:.+]] = tensor.empty()
197197
// CHECK: %[[T0:.+]] = tensor.expand_shape %[[ARG0]] {{\[\[}}0, 1], [2]] output_shape [8, 33, 4] : tensor<264x4xf32> into tensor<8x33x4xf32>
198-
// CHECK: %[[T1:.+]] = tensor.expand_shape %[[VAL_0]] {{\[\[}}0, 1], [2]] output_shape [8, 33, 4] : tensor<264x4xf32> into tensor<8x33x4xf32>
198+
// CHECK: %[[T1:.+]] = tensor.expand_shape %[[INIT]] {{\[\[}}0, 1], [2]] output_shape [8, 33, 4] : tensor<264x4xf32> into tensor<8x33x4xf32>
199199
// CHECK: %[[T2:.+]] = linalg.generic
200200
// CHECK-SAME: indexing_maps = [#[[MAP2]], #[[MAP2]], #[[MAP2]]]
201201
// CHECK-SAME: ["parallel", "parallel", "parallel"]
202202
// CHECK-SAME: ins(%[[T0]], %[[CST]] :
203203
// CHECK-SAME: outs(%[[T1]] : tensor<8x33x4xf32>)
204204
// CHECK: return %[[T2]] : tensor<8x33x4xf32>
205205

206+
// -----
207+
208+
func.func @reshape_as_consumer_transpose
209+
(%a : tensor<4x210x6xf32>)
210+
-> tensor<2x3x4x5x6x7xf32> {
211+
%b = tensor.empty() : tensor<6x4x210xf32>
212+
%c = linalg.transpose
213+
ins(%a : tensor<4x210x6xf32>)
214+
outs(%b : tensor<6x4x210xf32>) permutation = [2, 0, 1]
215+
%d = tensor.expand_shape %c [[0, 1], [2], [3, 4, 5]] output_shape [2, 3, 4, 5, 6, 7] : tensor<6x4x210xf32> into tensor<2x3x4x5x6x7xf32>
216+
return %d : tensor<2x3x4x5x6x7xf32>
217+
}
218+
// CHECK: func @reshape_as_consumer_transpose
219+
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<4x210x6xf32>
220+
// CHECK-DAG: %[[INIT:.+]] = tensor.empty()
221+
// CHECK-DAG: %[[T0:.+]] = tensor.expand_shape %[[ARG0]] {{\[\[}}0], [1, 2, 3], [4, 5]] output_shape [4, 5, 6, 7, 2, 3] : tensor<4x210x6xf32> into tensor<4x5x6x7x2x3xf32>
222+
// CHECK-DAG: %[[T1:.+]] = tensor.expand_shape %[[INIT]] {{\[\[}}0, 1], [2], [3, 4, 5]] output_shape [2, 3, 4, 5, 6, 7] : tensor<6x4x210xf32> into tensor<2x3x4x5x6x7xf32
223+
// CHECK: %[[T2:.+]] = linalg.transpose ins(%[[T0]] : tensor<4x5x6x7x2x3xf32>)
224+
// CHECK-SAME: outs(%[[T1]] : tensor<2x3x4x5x6x7xf32>)
225+
// CHECK-SAME: permutation = [4, 5, 0, 1, 2, 3]
226+
// CHECK: return %[[T2]] : tensor<2x3x4x5x6x7xf32>
227+
228+
206229
// -----
207230

208231
#map0 = affine_map<(d0, d1, d2) -> (d2, d0, d1)>
@@ -884,37 +907,29 @@ func.func @linalg_copy_reshape_producer_fusion(%arg0 : tensor<?x7x?x8xf32>,
884907

885908
// -----
886909

887-
func.func @linalg_transpose_reshape_producer_fusion(%arg0 : tensor<?x7x?x8xf32>,
888-
%arg1 : tensor<?x?xf32>) ->
889-
tensor<?x?xf32>
890-
{
891-
%0 = tensor.collapse_shape %arg0 [[0, 1], [2, 3]] :
892-
tensor<?x7x?x8xf32> into tensor<?x?xf32>
893-
%1 = linalg.transpose ins(%0 : tensor<?x?xf32>)
894-
outs(%arg1 : tensor<?x?xf32>) permutation = [1, 0]
895-
return %1 : tensor<?x?xf32>
910+
911+
func.func @reshape_as_producer_transpose
912+
(%a : tensor<4x5x6x7x2x3xf32>)
913+
-> tensor<6x4x210xf32> {
914+
%b = tensor.empty() : tensor<6x4x210xf32>
915+
%c = tensor.collapse_shape %a [[0], [1, 2, 3], [4, 5]] :
916+
tensor<4x5x6x7x2x3xf32> into tensor<4x210x6xf32>
917+
%d = linalg.transpose
918+
ins(%c : tensor<4x210x6xf32>)
919+
outs(%b : tensor<6x4x210xf32>) permutation = [2, 0, 1]
920+
return %d : tensor<6x4x210xf32>
896921
}
897922

898-
// CHECK: func @linalg_transpose_reshape_producer_fusion
899-
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x7x?x8xf32>
900-
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
901-
// CHECK-DAG: %[[C8:.+]] = arith.constant 8 : index
902-
// CHECK-DAG: %[[C7:.+]] = arith.constant 7 : index
903-
// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
904-
// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
905-
// CHECK-DAG: %[[DIM:.+]] = tensor.dim %[[ARG1]], %[[C0]] : tensor<?x?xf32>
906-
// CHECK-DAG: %[[DIM_0:.+]] = tensor.dim %[[ARG1]], %[[C1]] : tensor<?x?xf32>
907-
// CHECK-DAG: %[[VAL_0:.+]] = arith.divsi %[[DIM_0]], %[[C7]] : index
908-
// CHECK-DAG: %[[VAL_1:.+]] = arith.divsi %[[DIM]], %[[C8]] : index
909-
// CHECK: %[[T1:.+]] = tensor.expand_shape %[[ARG1]] {{\[\[}}0, 1], [2, 3]] output_shape [%[[VAL_1]], 8, %[[VAL_0]], 7] : tensor<?x?xf32> into tensor<?x8x?x7xf32>
910-
// CHECK: %[[T2:.+]] = linalg.transpose
911-
// CHECK-SAME: ins(%[[ARG0]] : tensor<?x7x?x8xf32>)
912-
// CHECK-SAME: outs(%[[T1]] : tensor<?x8x?x7xf32>)
913-
// CHECK-SAME: permutation = [2, 3, 0, 1]
914-
// CHECK: %[[T3:.+]] = tensor.collapse_shape %[[T2]]
915-
// CHECK-SAME: [0, 1], [2, 3]
916-
// CHECK-SAME: tensor<?x8x?x7xf32> into tensor<?x?xf32>
917-
// CHECK: return %[[T3]]
923+
// CHECK: func @reshape_as_producer_transpose
924+
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<4x5x6x7x2x3xf32>
925+
// CHECK-DAG: %[[INIT:.+]] = tensor.empty()
926+
// CHECK-DAG: %[[T0:.+]] = tensor.expand_shape %[[INIT]] {{\[\[}}0, 1], [2], [3, 4, 5]] output_shape [2, 3, 4, 5, 6, 7] : tensor<6x4x210xf32> into tensor<2x3x4x5x6x7xf32>
927+
// CHECK: %[[T1:.+]] = linalg.transpose ins(%[[ARG0]] : tensor<4x5x6x7x2x3xf32>)
928+
// CHECK-SAME: outs(%[[T0]] : tensor<2x3x4x5x6x7xf32>)
929+
// CHECK-SAME: permutation = [4, 5, 0, 1, 2, 3]
930+
// CHECK: %[[T2:.+]] = tensor.collapse_shape %[[T1]] {{\[\[}}0, 1], [2], [3, 4, 5]] : tensor<2x3x4x5x6x7xf32> into tensor<6x4x210xf32>
931+
// CHECK: return %[[T2]] : tensor<6x4x210xf32>
932+
918933

919934
// -----
920935

0 commit comments

Comments
 (0)