|
25 | 25 | #include "mlir/Interfaces/TilingInterface.h" |
26 | 26 | #include "mlir/Rewrite/FrozenRewritePatternSet.h" |
27 | 27 | #include "mlir/Transforms/GreedyPatternRewriteDriver.h" |
| 28 | +#include "mlir/Transforms/LoopInvariantCodeMotionUtils.h" |
28 | 29 | #include "llvm/ADT/ScopeExit.h" |
29 | 30 | #include "llvm/ADT/TypeSwitch.h" |
30 | 31 | #include "llvm/Support/Debug.h" |
@@ -1316,7 +1317,15 @@ getUntiledProducerFromSliceSource(OpOperand *source, |
1316 | 1317 | ArrayRef<LoopLikeOpInterface> loops) { |
1317 | 1318 | std::optional<OpOperand *> destinationIterArg; |
1318 | 1319 | assert(!loops.empty() && "expected non empty loops container"); |
| 1320 | + |
| 1321 | + // The `extractOp` may not reside within the innermost loop, calculate the |
| 1322 | + // distance between it and the last LoopLikeInterfaceOp. Adding this |
| 1323 | + // `distance` to `loopIt` yields the start of the loop. |
1319 | 1324 | auto loopIt = loops.rbegin(); |
| 1325 | + auto parentLoop = source->getOwner()->getParentOfType<LoopLikeOpInterface>(); |
| 1326 | + const LoopLikeOpInterface *it = llvm::find(loops, parentLoop); |
| 1327 | + int64_t distance = std::distance(loops.begin(), it); |
| 1328 | + loopIt += (loops.size() - distance - 1); |
1320 | 1329 | while (loopIt != loops.rend() && isa<BlockArgument>(source->get())) { |
1321 | 1330 | auto iterArg = cast<BlockArgument>(source->get()); |
1322 | 1331 | auto loop = *loopIt; |
@@ -1347,7 +1356,6 @@ mlir::scf::tileAndFuseProducerOfSlice( |
1347 | 1356 |
|
1348 | 1357 | OpBuilder::InsertionGuard g(rewriter); |
1349 | 1358 | rewriter.setInsertionPoint(candidateSliceOp); |
1350 | | - |
1351 | 1359 | // 2. Clone the fused producer |
1352 | 1360 | // 2a. Compute the destination operands to use for the cloned operation. |
1353 | 1361 | SmallVector<Value> origDestinationTensors, clonedOpDestinationTensors; |
@@ -1750,6 +1758,13 @@ mlir::scf::tileConsumerAndFuseProducersUsingSCF( |
1750 | 1758 | replacements}; |
1751 | 1759 | } |
1752 | 1760 |
|
| 1761 | + // The extract_slice op is created in the innermost loop by default. Using |
| 1762 | + // hoistLoopInvariantSubsets improves the position of the extract_slice op |
| 1763 | + // within the loops, allowing the fuse Op to be created in the correct loop. |
| 1764 | + for (LoopLikeOpInterface loop : loops) { |
| 1765 | + (void)hoistLoopInvariantSubsets(rewriter, loop); |
| 1766 | + } |
| 1767 | + |
1753 | 1768 | // Since the loop gets potentially replaced during fusion, we need to track |
1754 | 1769 | // the mutation of replacement values. To do this, we attach a listener to |
1755 | 1770 | // update the replacements as they happen. |
|
0 commit comments