Skip to content

Commit 21a8ff0

Browse files
Use -loop-invariant-subset-hoisting in the fuse transform.
1 parent b0b4a8e commit 21a8ff0

File tree

1 file changed

+16
-1
lines changed

1 file changed

+16
-1
lines changed

mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
#include "mlir/Interfaces/TilingInterface.h"
2626
#include "mlir/Rewrite/FrozenRewritePatternSet.h"
2727
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
28+
#include "mlir/Transforms/LoopInvariantCodeMotionUtils.h"
2829
#include "llvm/ADT/ScopeExit.h"
2930
#include "llvm/ADT/TypeSwitch.h"
3031
#include "llvm/Support/Debug.h"
@@ -1316,7 +1317,15 @@ getUntiledProducerFromSliceSource(OpOperand *source,
13161317
ArrayRef<LoopLikeOpInterface> loops) {
13171318
std::optional<OpOperand *> destinationIterArg;
13181319
assert(!loops.empty() && "expected non empty loops container");
1320+
1321+
// The `extractOp` (the owner of `source`) may not reside within the innermost
1322+
// loop. Find the position of its parent LoopLikeOpInterface in `loops` and
1323+
// advance `loopIt`, which starts at the last loop, to that parent loop.
13191324
auto loopIt = loops.rbegin();
1325+
auto parentLoop = source->getOwner()->getParentOfType<LoopLikeOpInterface>();
1326+
const LoopLikeOpInterface *it = llvm::find(loops, parentLoop);
1327+
int64_t distance = std::distance(loops.begin(), it);
1328+
loopIt += (loops.size() - distance - 1);
13201329
while (loopIt != loops.rend() && isa<BlockArgument>(source->get())) {
13211330
auto iterArg = cast<BlockArgument>(source->get());
13221331
auto loop = *loopIt;
@@ -1347,7 +1356,6 @@ mlir::scf::tileAndFuseProducerOfSlice(
13471356

13481357
OpBuilder::InsertionGuard g(rewriter);
13491358
rewriter.setInsertionPoint(candidateSliceOp);
1350-
13511359
// 2. Clone the fused producer
13521360
// 2a. Compute the destination operands to use for the cloned operation.
13531361
SmallVector<Value> origDestinationTensors, clonedOpDestinationTensors;
@@ -1750,6 +1758,13 @@ mlir::scf::tileConsumerAndFuseProducersUsingSCF(
17501758
replacements};
17511759
}
17521760

1761+
// The extract_slice op is created in the innermost loop by default. Running
1762+
// hoistLoopInvariantSubsets on each loop hoists the extract_slice op to a
1763+
// better position, so that the fused op is created in the correct loop.
1764+
for (LoopLikeOpInterface loop : loops) {
1765+
(void)hoistLoopInvariantSubsets(rewriter, loop);
1766+
}
1767+
17531768
// Since the loop gets potentially replaced during fusion, we need to track
17541769
// the mutation of replacement values. To do this, we attach a listener to
17551770
// update the replacements as they happen.

0 commit comments

Comments
 (0)