
Commit 189dc86

[Global Opt] Don't propagate edge reshapes (#22320)
Prevent propagating reshapes on the edges of the program in PropagateLinalgTranspose, since these reshapes don't block the propagation of transposes. This is similar to what is done in BubbleUpExpandShapes: https://github.com/iree-org/iree/blob/dad4b2d1cab357720a73c05d7bf6e1b946668082/compiler/src/iree/compiler/DispatchCreation/BubbleUpExpandShapes.cpp#L441-L459

Closes #22312

Signed-off-by: Ian Wood <[email protected]>
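To illustrate what an "edge" reshape is, here is a minimal sketch adapted from the @dont_propagate_edge_reshapes test added in this commit (the function name is invented for exposition). The tensor.collapse_shape reads directly from a function argument, so there is no producer op for fusion to absorb; bubbling the transpose above the reshape would churn the IR without enabling any fusion, and the pass now leaves such programs alone:

util.func @edge_collapse_example(%arg0: tensor<10x10x10xi32>) -> tensor<10x100xi32> {
  // The collapse_shape's source is a block argument (a program edge), so the
  // new isReshapeBlockingFusion check fails and the transpose stays in place.
  %collapsed = tensor.collapse_shape %arg0 [[0, 1], [2]]
      : tensor<10x10x10xi32> into tensor<100x10xi32>
  %empty = tensor.empty() : tensor<10x100xi32>
  %transpose = linalg.transpose ins(%collapsed : tensor<100x10xi32>)
      outs(%empty : tensor<10x100xi32>) permutation = [1, 0]
  util.return %transpose : tensor<10x100xi32>
}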
1 parent 9992921 commit 189dc86

2 files changed: 66 additions, 34 deletions

compiler/src/iree/compiler/GlobalOptimization/PropagateLinalgTranspose.cpp (38 additions, 0 deletions)
@@ -87,6 +87,18 @@ static RankedTensorType getPermutedTensorType(RankedTensorType type,
   return RankedTensorType::get(permutedShape, type.getElementType());
 }

+static bool isReshapeBlockingFusion(Operation *producer, Operation *consumer) {
+  auto isFusableOp = [](Operation *op) {
+    if (!op) {
+      return false;
+    }
+    return isa_and_nonnull<linalg::LinalgDialect,
+                           IREE::LinalgExt::IREELinalgExtDialect,
+                           tensor::TensorDialect>(op->getDialect());
+  };
+  return isFusableOp(producer) && isFusableOp(consumer);
+}
+
 //===----------------------------------------------------------------------===//
 // Transpose specialization
 //===----------------------------------------------------------------------===//
@@ -324,6 +336,12 @@ class BubbleTransposeThroughCollapseShape
           transposeOp, "transpose input is not a single-use collapse shape");
     }

+    if (!isReshapeBlockingFusion(transposeOp,
+                                 collapseOp.getSrc().getDefiningOp())) {
+      return rewriter.notifyMatchFailure(transposeOp,
+                                         "transpose not blocking fusion");
+    }
+
     SmallVector<ReassociationIndices> reassociations =
         collapseOp.getReassociationIndices();
@@ -521,6 +539,13 @@ class SinkTransposeThroughExpandShape
           expandOp, "expand shape input is not a single-use transpose");
     }

+    if (llvm::none_of(expandOp->getUsers(), [&](Operation *consumer) {
+          return isReshapeBlockingFusion(transposeOp, consumer);
+        })) {
+      return rewriter.notifyMatchFailure(transposeOp,
+                                         "transpose not blocking fusion");
+    }
+
     auto invPerm = invertPermutationVector(transposeOp.getPermutation());
     SmallVector<ReassociationIndices> reassociations =
         expandOp.getReassociationIndices();
@@ -1084,6 +1109,13 @@ void PropagateLinalgTransposePass::runOnOperation() {
         if (!isa<tensor::ExpandShapeOp>(consumer)) {
           return false;
         }
+
+        if (llvm::none_of(
+                consumer->getUsers(), [&](Operation *expandConsumer) {
+                  return isReshapeBlockingFusion(producer, expandConsumer);
+                })) {
+          return false;
+        }
         // Only propagate if the immediate consumer of the reshape is a
         // transpose.
         return consumer->hasOneUse() &&
@@ -1156,6 +1188,12 @@ void PropagateLinalgTransposePass::runOnOperation() {
         if (!isa<tensor::CollapseShapeOp>(producer)) {
           return false;
         }
+
+        if (!isReshapeBlockingFusion(producer->getOperand(0).getDefiningOp(),
+                                     consumer)) {
+          return false;
+        }
+
         // Require that the immediate producer of the reshape is a transpose.
         return isa_and_nonnull<linalg::TransposeOp>(
             producer->getOperand(0).getDefiningOp());
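For contrast, a hypothetical sketch that is not part of this commit: when the reshape sits between two fusable ops, for instance a linalg.abs producer and the linalg.transpose consumer below, isReshapeBlockingFusion returns true and BubbleTransposeThroughCollapseShape still bubbles the transpose above the collapse_shape:

util.func public @reshape_blocking_fusion(%arg0: tensor<2x3x4xf32>) -> tensor<4x6xf32> {
  // Both sides of the collapse_shape are linalg ops, so the reshape does
  // block fusion and propagating the transpose is still worthwhile.
  %empty0 = tensor.empty() : tensor<2x3x4xf32>
  %abs = linalg.abs ins(%arg0 : tensor<2x3x4xf32>)
      outs(%empty0 : tensor<2x3x4xf32>) -> tensor<2x3x4xf32>
  %collapsed = tensor.collapse_shape %abs [[0, 1], [2]]
      : tensor<2x3x4xf32> into tensor<6x4xf32>
  %empty1 = tensor.empty() : tensor<4x6xf32>
  %transposed = linalg.transpose ins(%collapsed : tensor<6x4xf32>)
      outs(%empty1 : tensor<4x6xf32>) permutation = [1, 0]
  util.return %transposed : tensor<4x6xf32>
}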

compiler/src/iree/compiler/GlobalOptimization/test/propagate_linalg_transpose.mlir (28 additions, 34 deletions)
@@ -358,40 +358,6 @@ util.func public @sink_through_expand_shape(%arg0: tensor<?x?x?xf32>) -> tensor<

 // -----

-util.func public @sink_non_involution_through_expand_shape(%arg0 : tensor<2x3x4xf32>) -> tensor<1x3x4x2xf32> {
-  %empty = tensor.empty(): tensor<3x4x2xf32>
-  %transposed = linalg.transpose ins(%arg0 : tensor<2x3x4xf32>)
-      outs(%empty : tensor<3x4x2xf32>) permutation = [1, 2, 0]
-  %expanded = tensor.expand_shape %transposed [[0, 1], [2], [3]] output_shape [1, 3, 4, 2] : tensor<3x4x2xf32> into tensor<1x3x4x2xf32>
-  util.return %expanded : tensor<1x3x4x2xf32>
-}
-// SINK-LABEL: util.func public @sink_non_involution_through_expand_shape
-// SINK: %[[EXP:.+]] = tensor.expand_shape {{.*}} {{\[\[}}0], [1, 2], [3]]
-// SINK-SAME: tensor<2x3x4xf32> into tensor<2x1x3x4xf32>
-// SINK: %[[RES:.+]] = linalg.transpose ins(%[[EXP]] : tensor<2x1x3x4xf32>
-// SINK-SAME: outs({{.*}} : tensor<1x3x4x2xf32>)
-// SINK-SAME: permutation = [1, 2, 3, 0]
-// SINK: util.return %[[RES]] : tensor<1x3x4x2xf32>
-
-// -----
-
-util.func public @bubble_non_involution_through_collapse_shape(%arg0 : tensor<1x2x3x5x7x11xf32>) -> tensor<35x11x6xf32> {
-  %collapsed = tensor.collapse_shape %arg0 [[0, 1, 2], [3, 4], [5]] : tensor<1x2x3x5x7x11xf32> into tensor<6x35x11xf32>
-  %empty = tensor.empty(): tensor<35x11x6xf32>
-  %transposed = linalg.transpose ins(%collapsed : tensor<6x35x11xf32>)
-      outs(%empty : tensor<35x11x6xf32>) permutation = [1, 2, 0]
-  util.return %transposed : tensor<35x11x6xf32>
-}
-// BUBBLE-LABEL: util.func public @bubble_non_involution_through_collapse_shape
-// BUBBLE: %[[T:.+]] = linalg.transpose ins(%{{.*}} : tensor<1x2x3x5x7x11xf32>
-// BUBBLE-SAME: outs({{.*}} : tensor<5x7x11x1x2x3xf32>)
-// BUBBLE-SAME: permutation = [3, 4, 5, 0, 1, 2]
-// BUBBLE: %[[COL:.+]] = tensor.collapse_shape %[[T]] {{\[\[}}0, 1], [2], [3, 4, 5]]
-// BUBBLE-SAME: tensor<5x7x11x1x2x3xf32> into tensor<35x11x6xf32>
-// BUBBLE: util.return %[[COL]] : tensor<35x11x6xf32>
-
-// -----
-
 util.func public @propagate_transpose_through_unary_elementwise(%arg0 : tensor<2x3x4xf32>) -> tensor<3x4x2xf32> {
   %empty = tensor.empty(): tensor<3x4x2xf32>
   %transposed = linalg.transpose ins(%arg0 : tensor<2x3x4xf32>)
@@ -799,3 +765,31 @@ util.func public @dont_reshape_reduction(%arg0: tensor<16x4x4xf32>, %arg1: tenso
 // APROP: %[[V1:.+]] = tensor.collapse_shape %[[V0]]
 // APROP: %[[V2:.+]] = linalg.matmul ins(%[[V1]]
 // APROP: util.return %[[V2]]
+
+// -----
+
+util.func @dont_propagate_edge_reshapes(%arg0: tensor<10x10x10xi32>) -> tensor<10x100xi32> {
+  %collapsed = tensor.collapse_shape %arg0 [[0, 1], [2]] : tensor<10x10x10xi32> into tensor<100x10xi32>
+  %empty = tensor.empty() : tensor<10x100xi32>
+  %transpose = linalg.transpose ins(%collapsed : tensor<100x10xi32>) outs(%empty : tensor<10x100xi32>) permutation = [1, 0]
+  util.return %transpose : tensor<10x100xi32>
+}
+// CHECK-LABEL: util.func public @dont_propagate_edge_reshapes
+// CHECK-SAME: %[[ARG0:[0-9a-zA-Z]+]]
+// CHECK: %[[COLLAPSED:.+]] = tensor.collapse_shape %[[ARG0]]
+// CHECK: %[[VAL:.+]] = linalg.transpose ins(%[[COLLAPSED]]
+// CHECK: util.return %[[VAL]]
+
+// -----
+
+util.func public @dont_sink_through_edge_expand_shape(%arg0 : tensor<2x3x4xf32>) -> tensor<1x3x4x2xf32> {
+  %empty = tensor.empty(): tensor<3x4x2xf32>
+  %transposed = linalg.transpose ins(%arg0 : tensor<2x3x4xf32>)
+      outs(%empty : tensor<3x4x2xf32>) permutation = [1, 2, 0]
+  %expanded = tensor.expand_shape %transposed [[0, 1], [2], [3]] output_shape [1, 3, 4, 2] : tensor<3x4x2xf32> into tensor<1x3x4x2xf32>
+  util.return %expanded : tensor<1x3x4x2xf32>
+}
+// SINK-LABEL: util.func public @dont_sink_through_edge_expand_shape
+// SINK: %[[TRANSPOSE:.+]] = linalg.transpose
+// SINK: %[[RES:.+]] = tensor.expand_shape %[[TRANSPOSE]]
+// SINK: util.return %[[RES]] : tensor<1x3x4x2xf32>
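
Correspondingly for the sink direction, another hypothetical sketch that is not part of this commit: when the tensor.expand_shape result feeds a fusable consumer such as a linalg op, the blocking-fusion check holds for at least one user and SinkTransposeThroughExpandShape still sinks the transpose past the reshape:

util.func public @expand_blocking_fusion(%arg0 : tensor<2x3x4xf32>) -> tensor<1x3x4x2xf32> {
  %empty = tensor.empty() : tensor<3x4x2xf32>
  %transposed = linalg.transpose ins(%arg0 : tensor<2x3x4xf32>)
      outs(%empty : tensor<3x4x2xf32>) permutation = [1, 2, 0]
  %expanded = tensor.expand_shape %transposed [[0, 1], [2], [3]] output_shape [1, 3, 4, 2]
      : tensor<3x4x2xf32> into tensor<1x3x4x2xf32>
  // The expand_shape feeds a linalg consumer, so the reshape does block
  // fusion and the transpose can still be sunk through it.
  %out = tensor.empty() : tensor<1x3x4x2xf32>
  %abs = linalg.abs ins(%expanded : tensor<1x3x4x2xf32>)
      outs(%out : tensor<1x3x4x2xf32>) -> tensor<1x3x4x2xf32>
  util.return %abs : tensor<1x3x4x2xf32>
}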
