Skip to content

Commit 7e60842

Browse files
IanWood1 and pstarkcdpr
authored and committed
[Global Opt] Add flag to control edge reshape propagation (iree-org#22438)
Adds flag to allow users to selectively disable the changes from iree-org#22320. This is a stopgap for cases where not fusing a transpose into a matmul gives better performance. Signed-off-by: Ian Wood <[email protected]>
1 parent 9ee54cc commit 7e60842

File tree

4 files changed

+52
-7
lines changed

4 files changed

+52
-7
lines changed

compiler/src/iree/compiler/GlobalOptimization/Passes.cpp

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,15 @@ static llvm::cl::opt<bool> clWarnOnUninitializedValues(
6868
"iree-global-opt-enable-warn-on-uninitialized-values",
6969
llvm::cl::desc("Warn on some classes of uses of uninitialized values."),
7070
llvm::cl::init(true));
71+
72+
static llvm::cl::opt<bool> clEnableEdgeReshapePropagation(
73+
"iree-global-opt-experimental-enable-edge-reshape-propagation",
74+
llvm::cl::desc(
75+
"Enables propagation of reshapes on the edges of the program "
76+
"in transpose propagation. This is a workaround for better performance "
77+
"and will be removed soon."),
78+
llvm::cl::init(false));
79+
7180
void buildGlobalOptExprHoistingPassPipeline(
7281
OpPassManager &passManager, const TransformOptions &transformOptions) {
7382
IREE::Util::ExprHoistingOptions options;
@@ -170,6 +179,8 @@ void buildGlobalOptimizationPassPipeline(
170179
transformOptions.aggressiveTransposePropagation;
171180
options.enableAttentionVTranspose =
172181
clEnableAttentionVTranspose;
182+
options.enableEdgeReshapePropagation =
183+
clEnableEdgeReshapePropagation;
173184
return createPropagateLinalgTransposePass(options);
174185
})
175186
.addPass(IREE::Flow::createCanonicalizePass)

compiler/src/iree/compiler/GlobalOptimization/Passes.td

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,8 @@ def PropagateLinalgTransposePass :
119119
/*default=*/"true", "Enable transposition of attention v operand">,
120120
Option<"enableConvolutionPropagation", "enable-aggressive-propagation-through-conv", "bool",
121121
/*default=*/"false", "enable propagation through convolutions">,
122+
Option<"enableEdgeReshapePropagation", "enable-edge-reshape-propagation", "bool",
123+
/*default=*/"false", "Enable propagation of reshapes on the edges of the program">,
122124
];
123125
}
124126

compiler/src/iree/compiler/GlobalOptimization/PropagateLinalgTranspose.cpp

Lines changed: 29 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -320,6 +320,11 @@ class BubbleTransposeThroughCollapseShape
320320
: public OpRewritePattern<linalg::TransposeOp> {
321321
public:
322322
using Base::Base;
323+
BubbleTransposeThroughCollapseShape(MLIRContext *ctx,
324+
bool enableEdgeReshapeProp,
325+
PatternBenefit b = 1)
326+
: OpRewritePattern<linalg::TransposeOp>(ctx, b),
327+
enableEdgeReshapePropagation(enableEdgeReshapeProp) {}
323328

324329
LogicalResult matchAndRewrite(linalg::TransposeOp transposeOp,
325330
PatternRewriter &rewriter) const override {
@@ -336,7 +341,8 @@ class BubbleTransposeThroughCollapseShape
336341
transposeOp, "transpose input is not a single-use collapse shape");
337342
}
338343

339-
if (!isReshapeBlockingFusion(transposeOp,
344+
if (!enableEdgeReshapePropagation &&
345+
!isReshapeBlockingFusion(transposeOp,
340346
collapseOp.getSrc().getDefiningOp())) {
341347
return rewriter.notifyMatchFailure(transposeOp,
342348
"transpose not blocking fusion");
@@ -379,6 +385,9 @@ class BubbleTransposeThroughCollapseShape
379385
rewriter.replaceOp(transposeOp, newReshape);
380386
return success();
381387
}
388+
389+
private:
390+
bool enableEdgeReshapePropagation = true;
382391
};
383392

384393
} // namespace
@@ -523,6 +532,10 @@ class SinkTransposeThroughExpandShape
523532
: public OpRewritePattern<tensor::ExpandShapeOp> {
524533
public:
525534
using Base::Base;
535+
SinkTransposeThroughExpandShape(MLIRContext *ctx, bool enableEdgeReshapeProp,
536+
PatternBenefit b = 1)
537+
: OpRewritePattern<tensor::ExpandShapeOp>(ctx, b),
538+
enableEdgeReshapePropagation(enableEdgeReshapeProp) {}
526539

527540
LogicalResult matchAndRewrite(tensor::ExpandShapeOp expandOp,
528541
PatternRewriter &rewriter) const override {
@@ -539,7 +552,8 @@ class SinkTransposeThroughExpandShape
539552
expandOp, "expand shape input is not a single-use transpose");
540553
}
541554

542-
if (llvm::none_of(expandOp->getUsers(), [&](Operation *consumer) {
555+
if (!enableEdgeReshapePropagation &&
556+
llvm::none_of(expandOp->getUsers(), [&](Operation *consumer) {
543557
return isReshapeBlockingFusion(transposeOp, consumer);
544558
})) {
545559
return rewriter.notifyMatchFailure(transposeOp,
@@ -588,6 +602,9 @@ class SinkTransposeThroughExpandShape
588602
rewriter.replaceOp(expandOp, originalReshape);
589603
return success();
590604
}
605+
606+
private:
607+
bool enableEdgeReshapePropagation = true;
591608
};
592609

593610
// Fuses a transpose with the input of a linalg.generic op or contraction op.
@@ -1072,7 +1089,8 @@ void PropagateLinalgTransposePass::runOnOperation() {
10721089
if (!testBubblingOnly) {
10731090
RewritePatternSet sinkingPatterns(context);
10741091
sinkingPatterns.insert<SinkTransposeThroughExtractSlice>(context);
1075-
sinkingPatterns.insert<SinkTransposeThroughExpandShape>(context);
1092+
sinkingPatterns.insert<SinkTransposeThroughExpandShape>(
1093+
context, enableEdgeReshapePropagation);
10761094
populateNamedOpSinkingPatterns(context, sinkingPatterns);
10771095
populateCommonCanonicalizationPatterns(context, sinkingPatterns);
10781096
sinkingPatterns.add<SinkTransposeThroughUnaryElementwiseInput>(
@@ -1118,7 +1136,8 @@ void PropagateLinalgTransposePass::runOnOperation() {
11181136
return false;
11191137
}
11201138

1121-
if (llvm::none_of(
1139+
if (!enableEdgeReshapePropagation &&
1140+
llvm::none_of(
11221141
consumer->getUsers(), [&](Operation *expandConsumer) {
11231142
return isReshapeBlockingFusion(producer, expandConsumer);
11241143
})) {
@@ -1148,7 +1167,8 @@ void PropagateLinalgTransposePass::runOnOperation() {
11481167
}
11491168
bubblingPatterns.insert<FuseTransposeWithProducerLinalgOp>(
11501169
context, enableAggressivePropagation, enableConvolutionPropagation);
1151-
bubblingPatterns.insert<BubbleTransposeThroughCollapseShape>(context);
1170+
bubblingPatterns.insert<BubbleTransposeThroughCollapseShape>(
1171+
context, enableEdgeReshapePropagation);
11521172
bubblingPatterns.add<BubbleTransposeThroughUnaryElementwiseDpsInit>(
11531173
context, /*benefit=*/2);
11541174
bubblingPatterns.insert<ComposeTransposes>(context);
@@ -1197,7 +1217,8 @@ void PropagateLinalgTransposePass::runOnOperation() {
11971217
return false;
11981218
}
11991219

1200-
if (!isReshapeBlockingFusion(producer->getOperand(0).getDefiningOp(),
1220+
if (!enableEdgeReshapePropagation &&
1221+
!isReshapeBlockingFusion(producer->getOperand(0).getDefiningOp(),
12011222
consumer)) {
12021223
return false;
12031224
}
@@ -1209,7 +1230,8 @@ void PropagateLinalgTransposePass::runOnOperation() {
12091230
linalg::populateFoldReshapeOpsByExpansionPatterns(sinkingPatterns,
12101231
reshapePropagationFn);
12111232
sinkingPatterns.insert<SinkTransposeThroughExtractSlice>(context);
1212-
sinkingPatterns.insert<SinkTransposeThroughExpandShape>(context);
1233+
sinkingPatterns.insert<SinkTransposeThroughExpandShape>(
1234+
context, enableEdgeReshapePropagation);
12131235
sinkingPatterns.insert<FuseTransposeWithLinalgOpConsumer>(
12141236
context, enableAggressivePropagation, enableConvolutionPropagation);
12151237
sinkingPatterns.insert<ComposeTransposes>(context);

compiler/src/iree/compiler/GlobalOptimization/test/propagate_linalg_transpose.mlir

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
// RUN: iree-opt --pass-pipeline="builtin.module(util.func(iree-global-opt-propagate-linalg-transpose{test-sinking-only=true}))" --split-input-file %s | FileCheck %s --check-prefix=SINK
44
// RUN: iree-opt --pass-pipeline="builtin.module(util.func(iree-global-opt-propagate-linalg-transpose{test-bubbling-only=true}))" --split-input-file %s | FileCheck %s --check-prefix=BUBBLE
55
// RUN: iree-opt --pass-pipeline="builtin.module(util.func(iree-global-opt-propagate-linalg-transpose{enable-aggressive-propagation-through-conv=true}))" --split-input-file %s | FileCheck %s --check-prefix=CONV
6+
// RUN: iree-opt --pass-pipeline="builtin.module(util.func(iree-global-opt-propagate-linalg-transpose{enable-edge-reshape-propagation=true}))" %s -o - | FileCheck %s --check-prefix=ENABLE-EDGE-PROP
67

78
util.func public @specialize_transpose_op(%arg0 : tensor<1x2x3xf32>,
89
%empty : tensor<3x2x1xf32>) -> tensor<3x2x1xf32> {
@@ -819,6 +820,11 @@ util.func @dont_propagate_edge_reshapes(%arg0: tensor<10x10x10xi32>) -> tensor<1
819820
// CHECK: %[[COLLAPSED:.+]] = tensor.collapse_shape %[[ARG0]]
820821
// CHECK: %[[VAL:.+]] = linalg.transpose ins(%[[COLLAPSED]]
821822
// CHECK: util.return %[[VAL]]
823+
// ENABLE-EDGE-PROP-LABEL: util.func public @dont_propagate_edge_reshapes
824+
// ENABLE-EDGE-PROP-SAME: %[[ARG0:[0-9a-zA-Z]+]]
825+
// ENABLE-EDGE-PROP: %[[TRANSPOSED:.+]] = linalg.transpose ins(%[[ARG0]]
826+
// ENABLE-EDGE-PROP: %[[COLLAPSED:.+]] = tensor.collapse_shape %[[TRANSPOSED]]
827+
// ENABLE-EDGE-PROP: util.return %[[COLLAPSED]]
822828

823829
// -----
824830

@@ -833,3 +839,7 @@ util.func public @dont_sink_through_edge_expand_shape(%arg0 : tensor<2x3x4xf32>)
833839
// SINK: %[[TRANSPOSE:.+]] = linalg.transpose
834840
// SINK: %[[RES:.+]] = tensor.expand_shape %[[TRANSPOSE]]
835841
// SINK: util.return %[[RES]] : tensor<1x3x4x2xf32>
842+
// ENABLE-EDGE-PROP-LABEL: util.func public @dont_sink_through_edge_expand_shape
843+
// ENABLE-EDGE-PROP: %[[EXP:.+]] = tensor.expand_shape
844+
// ENABLE-EDGE-PROP: %[[RES:.+]] = linalg.transpose
845+
// ENABLE-EDGE-PROP: util.return %[[RES]]

0 commit comments

Comments
 (0)