2828#include " mlir/Interfaces/TilingInterface.h"
2929#include " mlir/Rewrite/FrozenRewritePatternSet.h"
3030#include " mlir/Transforms/GreedyPatternRewriteDriver.h"
31+ #include " llvm/ADT/ScopeExit.h"
3132#include " llvm/ADT/TypeSwitch.h"
3233#include " llvm/Support/Debug.h"
3334#include < optional>
@@ -1467,6 +1468,47 @@ void SliceTrackingListener::notifyOperationReplaced(Operation *op,
14671468 ValueRange replacement) {
14681469 removeOp (op);
14691470}
1471+
1472+ // ===----------------------------------------------------------------------===//
1473+ // ReplacementListener
1474+ // ===----------------------------------------------------------------------===//
1475+
1476+ // / Listener that tracks updates replacements for values which can be mutated.
1477+ // / This listener runs on top of the existing listener for the rewriter,
1478+ // / to make sure external users can still run listeners.
1479+ class ReplacementListener : public RewriterBase ::ForwardingListener {
1480+ public:
1481+ ReplacementListener (DenseMap<Value, Value> &replacements,
1482+ OpBuilder::Listener *listener)
1483+ : ForwardingListener(listener), replacements(replacements) {}
1484+
1485+ void updateReplacementValues (ValueRange origValues,
1486+ ValueRange replaceValues) {
1487+ // This can probably be written better, but just iterates over the map
1488+ // and the new replacements for now.
1489+ for (auto &[key, val] : replacements) {
1490+ for (auto [orig, replace] : llvm::zip_equal (origValues, replaceValues)) {
1491+ if (val == orig) {
1492+ val = replace;
1493+ }
1494+ }
1495+ }
1496+ }
1497+
1498+ void notifyOperationReplaced (Operation *op, Operation *newOp) override {
1499+ ForwardingListener::notifyOperationReplaced (op, newOp);
1500+ updateReplacementValues (op->getResults (), newOp->getResults ());
1501+ }
1502+
1503+ void notifyOperationReplaced (Operation *op, ValueRange values) override {
1504+ ForwardingListener::notifyOperationReplaced (op, values);
1505+ updateReplacementValues (op->getResults (), values);
1506+ }
1507+
1508+ private:
1509+ DenseMap<Value, Value> &replacements;
1510+ };
1511+
14701512} // namespace
14711513
14721514// / Implementation of tile consumer and fuse producer greedily.
@@ -1493,26 +1535,27 @@ mlir::scf::tileConsumerAndFuseProducersUsingSCF(
14931535 for (auto *tiledOp : tilingResult->tiledOps )
14941536 tiledAndFusedOps.insert (tiledOp);
14951537
1538+ DenseMap<Value, Value> replacements;
1539+ for (auto [origVal, replacement] : llvm::zip_equal (
1540+ consumer->getResults (), tilingResult->mergeResult .replacements )) {
1541+ replacements[origVal] = replacement;
1542+ }
1543+
14961544 // If there are no loops generated, fusion is immaterial.
14971545 auto &loops = tilingResult->loops ;
14981546 if (loops.empty ()) {
1499- DenseMap<Value, Value> replacements;
1500- for (auto [origVal, replacement] : llvm::zip_equal (
1501- consumer->getResults (), tilingResult->mergeResult .replacements )) {
1502- replacements[origVal] = replacement;
1503- }
15041547 return scf::SCFTileAndFuseResult{fusedProducers, tiledAndFusedOps, loops,
15051548 replacements};
15061549 }
15071550
1508- // To keep track of replacements for now just record the map from the
1509- // original untiled value to the result number of the for loop. Since the
1510- // loop gets potentially replaced during fusion, keeping the value directly
1511- // wont work.
1512- DenseMap<Value, size_t > origValToResultNumber;
1513- for ( auto [index, result] : llvm::enumerate (consumer-> getResults ())) {
1514- origValToResultNumber[result] = index ;
1515- }
1551+ // Since the loop gets potentially replaced during fusion, we need to track
1552+ // the mutation of replacement values. To do this, we attach a listener to
1553+ // update the replacements as they happen.
1554+ OpBuilder::Listener *previousListener = rewriter. getListener ();
1555+ auto resetListener =
1556+ llvm::make_scope_exit ([&]() { rewriter. setListener (previousListener); });
1557+ ReplacementListener replaceListener (replacements, previousListener) ;
1558+ rewriter. setListener (&replaceListener);
15161559
15171560 // 2. Typically, the operands of the tiled operation are slices of the
15181561 // operands of the untiled operation. These are expressed in IR using
@@ -1581,9 +1624,9 @@ mlir::scf::tileConsumerAndFuseProducersUsingSCF(
15811624 worklistCandidates.append (newSlices.value ());
15821625 for (auto [index, result] :
15831626 llvm::enumerate (fusableProducerOp->getResults ())) {
1584- origValToResultNumber [result] = loops.front ()->getNumResults () -
1585- fusableProducerOp ->getNumResults () +
1586- index;
1627+ replacements [result] = loops.front ()->getResult (
1628+ loops. front () ->getNumResults () -
1629+ fusableProducerOp-> getNumResults () + index) ;
15871630 }
15881631 }
15891632 if (Operation *tiledAndFusedOp =
@@ -1597,11 +1640,6 @@ mlir::scf::tileConsumerAndFuseProducersUsingSCF(
15971640 }
15981641 }
15991642
1600- DenseMap<Value, Value> replacements;
1601- for (auto [origVal, resultNumber] : origValToResultNumber) {
1602- replacements[origVal] = loops.front ()->getResult (resultNumber);
1603- }
1604-
16051643 return scf::SCFTileAndFuseResult{fusedProducers, tiledAndFusedOps, loops,
16061644 replacements};
16071645}
0 commit comments