@@ -320,6 +320,11 @@ class BubbleTransposeThroughCollapseShape
320320 : public OpRewritePattern<linalg::TransposeOp> {
321321public:
322322 using Base::Base;
323+ BubbleTransposeThroughCollapseShape (MLIRContext *ctx,
324+ bool enableEdgeReshapeProp,
325+ PatternBenefit b = 1 )
326+ : OpRewritePattern<linalg::TransposeOp>(ctx, b),
327+ enableEdgeReshapePropagation (enableEdgeReshapeProp) {}
323328
324329 LogicalResult matchAndRewrite (linalg::TransposeOp transposeOp,
325330 PatternRewriter &rewriter) const override {
@@ -336,7 +341,8 @@ class BubbleTransposeThroughCollapseShape
336341 transposeOp, " transpose input is not a single-use collapse shape" );
337342 }
338343
339- if (!isReshapeBlockingFusion (transposeOp,
344+ if (!enableEdgeReshapePropagation &&
345+ !isReshapeBlockingFusion (transposeOp,
340346 collapseOp.getSrc ().getDefiningOp ())) {
341347 return rewriter.notifyMatchFailure (transposeOp,
342348 " transpose not blocking fusion" );
@@ -379,6 +385,9 @@ class BubbleTransposeThroughCollapseShape
379385 rewriter.replaceOp (transposeOp, newReshape);
380386 return success ();
381387 }
388+
389+ private:
390+ bool enableEdgeReshapePropagation = true ;
382391};
383392
384393} // namespace
@@ -523,6 +532,10 @@ class SinkTransposeThroughExpandShape
523532 : public OpRewritePattern<tensor::ExpandShapeOp> {
524533public:
525534 using Base::Base;
535+ SinkTransposeThroughExpandShape (MLIRContext *ctx, bool enableEdgeReshapeProp,
536+ PatternBenefit b = 1 )
537+ : OpRewritePattern<tensor::ExpandShapeOp>(ctx, b),
538+ enableEdgeReshapePropagation (enableEdgeReshapeProp) {}
526539
527540 LogicalResult matchAndRewrite (tensor::ExpandShapeOp expandOp,
528541 PatternRewriter &rewriter) const override {
@@ -539,7 +552,8 @@ class SinkTransposeThroughExpandShape
539552 expandOp, " expand shape input is not a single-use transpose" );
540553 }
541554
542- if (llvm::none_of (expandOp->getUsers (), [&](Operation *consumer) {
555+ if (!enableEdgeReshapePropagation &&
556+ llvm::none_of (expandOp->getUsers (), [&](Operation *consumer) {
543557 return isReshapeBlockingFusion (transposeOp, consumer);
544558 })) {
545559 return rewriter.notifyMatchFailure (transposeOp,
@@ -588,6 +602,9 @@ class SinkTransposeThroughExpandShape
588602 rewriter.replaceOp (expandOp, originalReshape);
589603 return success ();
590604 }
605+
606+ private:
607+ bool enableEdgeReshapePropagation = true ;
591608};
592609
593610// Fuses a transpose with the input of a linalg.generic op or contraction op.
@@ -1072,7 +1089,8 @@ void PropagateLinalgTransposePass::runOnOperation() {
10721089 if (!testBubblingOnly) {
10731090 RewritePatternSet sinkingPatterns (context);
10741091 sinkingPatterns.insert <SinkTransposeThroughExtractSlice>(context);
1075- sinkingPatterns.insert <SinkTransposeThroughExpandShape>(context);
1092+ sinkingPatterns.insert <SinkTransposeThroughExpandShape>(
1093+ context, enableEdgeReshapePropagation);
10761094 populateNamedOpSinkingPatterns (context, sinkingPatterns);
10771095 populateCommonCanonicalizationPatterns (context, sinkingPatterns);
10781096 sinkingPatterns.add <SinkTransposeThroughUnaryElementwiseInput>(
@@ -1118,7 +1136,8 @@ void PropagateLinalgTransposePass::runOnOperation() {
11181136 return false ;
11191137 }
11201138
1121- if (llvm::none_of (
1139+ if (!enableEdgeReshapePropagation &&
1140+ llvm::none_of (
11221141 consumer->getUsers (), [&](Operation *expandConsumer) {
11231142 return isReshapeBlockingFusion (producer, expandConsumer);
11241143 })) {
@@ -1148,7 +1167,8 @@ void PropagateLinalgTransposePass::runOnOperation() {
11481167 }
11491168 bubblingPatterns.insert <FuseTransposeWithProducerLinalgOp>(
11501169 context, enableAggressivePropagation, enableConvolutionPropagation);
1151- bubblingPatterns.insert <BubbleTransposeThroughCollapseShape>(context);
1170+ bubblingPatterns.insert <BubbleTransposeThroughCollapseShape>(
1171+ context, enableEdgeReshapePropagation);
11521172 bubblingPatterns.add <BubbleTransposeThroughUnaryElementwiseDpsInit>(
11531173 context, /* benefit=*/ 2 );
11541174 bubblingPatterns.insert <ComposeTransposes>(context);
@@ -1197,7 +1217,8 @@ void PropagateLinalgTransposePass::runOnOperation() {
11971217 return false ;
11981218 }
11991219
1200- if (!isReshapeBlockingFusion (producer->getOperand (0 ).getDefiningOp (),
1220+ if (!enableEdgeReshapePropagation &&
1221+ !isReshapeBlockingFusion (producer->getOperand (0 ).getDefiningOp (),
12011222 consumer)) {
12021223 return false ;
12031224 }
@@ -1209,7 +1230,8 @@ void PropagateLinalgTransposePass::runOnOperation() {
12091230 linalg::populateFoldReshapeOpsByExpansionPatterns (sinkingPatterns,
12101231 reshapePropagationFn);
12111232 sinkingPatterns.insert <SinkTransposeThroughExtractSlice>(context);
1212- sinkingPatterns.insert <SinkTransposeThroughExpandShape>(context);
1233+ sinkingPatterns.insert <SinkTransposeThroughExpandShape>(
1234+ context, enableEdgeReshapePropagation);
12131235 sinkingPatterns.insert <FuseTransposeWithLinalgOpConsumer>(
12141236 context, enableAggressivePropagation, enableConvolutionPropagation);
12151237 sinkingPatterns.insert <ComposeTransposes>(context);
0 commit comments