2828#include " mlir/Interfaces/TilingInterface.h"
2929#include " mlir/Rewrite/FrozenRewritePatternSet.h"
3030#include " mlir/Transforms/GreedyPatternRewriteDriver.h"
31- #include " llvm/ADT/ScopeExit.h"
3231#include " llvm/ADT/TypeSwitch.h"
3332#include " llvm/Support/Debug.h"
3433#include < optional>
@@ -1468,47 +1467,6 @@ void SliceTrackingListener::notifyOperationReplaced(Operation *op,
14681467 ValueRange replacement) {
14691468 removeOp (op);
14701469}
1471-
1472- // ===----------------------------------------------------------------------===//
1473- // ReplacementListener
1474- // ===----------------------------------------------------------------------===//
1475-
1476- // / Listener that tracks updates replacements for values which can be mutated.
1477- // / This listener runs on top of the existing listener for the rewriter,
1478- // / to make sure external users can still run listeners.
1479- class ReplacementListener : public RewriterBase ::ForwardingListener {
1480- public:
1481- ReplacementListener (DenseMap<Value, Value> &replacements,
1482- OpBuilder::Listener *listener)
1483- : ForwardingListener(listener), replacements(replacements) {}
1484-
1485- void updateReplacementValues (ValueRange origValues,
1486- ValueRange replaceValues) {
1487- // This can probably be written better, but just iterates over the map
1488- // and the new replacements for now.
1489- for (auto &[key, val] : replacements) {
1490- for (auto [orig, replace] : llvm::zip_equal (origValues, replaceValues)) {
1491- if (val == orig) {
1492- val = replace;
1493- }
1494- }
1495- }
1496- }
1497-
1498- void notifyOperationReplaced (Operation *op, Operation *newOp) override {
1499- ForwardingListener::notifyOperationReplaced (op, newOp);
1500- updateReplacementValues (op->getResults (), newOp->getResults ());
1501- }
1502-
1503- void notifyOperationReplaced (Operation *op, ValueRange values) override {
1504- ForwardingListener::notifyOperationReplaced (op, values);
1505- updateReplacementValues (op->getResults (), values);
1506- }
1507-
1508- private:
1509- DenseMap<Value, Value> &replacements;
1510- };
1511-
15121470} // namespace
15131471
15141472// / Implementation of tile consumer and fuse producer greedily.
@@ -1535,27 +1493,26 @@ mlir::scf::tileConsumerAndFuseProducersUsingSCF(
15351493 for (auto *tiledOp : tilingResult->tiledOps )
15361494 tiledAndFusedOps.insert (tiledOp);
15371495
1538- DenseMap<Value, Value> replacements;
1539- for (auto [origVal, replacement] : llvm::zip_equal (
1540- consumer->getResults (), tilingResult->mergeResult .replacements )) {
1541- replacements[origVal] = replacement;
1542- }
1543-
15441496 // If there are no loops generated, fusion is immaterial.
15451497 auto &loops = tilingResult->loops ;
15461498 if (loops.empty ()) {
1499+ DenseMap<Value, Value> replacements;
1500+ for (auto [origVal, replacement] : llvm::zip_equal (
1501+ consumer->getResults (), tilingResult->mergeResult .replacements )) {
1502+ replacements[origVal] = replacement;
1503+ }
15471504 return scf::SCFTileAndFuseResult{fusedProducers, tiledAndFusedOps, loops,
15481505 replacements};
15491506 }
15501507
1551- // Since the loop gets potentially replaced during fusion, we need to track
1552- // the mutation of replacement values. To do this, we attach a listener to
1553- // update the replacements as they happen.
1554- OpBuilder::Listener *previousListener = rewriter. getListener ();
1555- auto resetListener =
1556- llvm::make_scope_exit ([&]() { rewriter. setListener (previousListener); });
1557- ReplacementListener replaceListener (replacements, previousListener) ;
1558- rewriter. setListener (&replaceListener);
1508+ // To keep track of replacements for now just record the map from the
1509+ // original untiled value to the result number of the for loop. Since the
1510+ // loop gets potentially replaced during fusion, keeping the value directly
1511+ // wont work.
1512+ DenseMap<Value, size_t > origValToResultNumber;
1513+ for ( auto [index, result] : llvm::enumerate (consumer-> getResults ())) {
1514+ origValToResultNumber[result] = index ;
1515+ }
15591516
15601517 // 2. Typically, the operands of the tiled operation are slices of the
15611518 // operands of the untiled operation. These are expressed in IR using
@@ -1624,9 +1581,9 @@ mlir::scf::tileConsumerAndFuseProducersUsingSCF(
16241581 worklistCandidates.append (newSlices.value ());
16251582 for (auto [index, result] :
16261583 llvm::enumerate (fusableProducerOp->getResults ())) {
1627- replacements [result] = loops.front ()->getResult (
1628- loops. front () ->getNumResults () -
1629- fusableProducerOp-> getNumResults () + index) ;
1584+ origValToResultNumber [result] = loops.front ()->getNumResults () -
1585+ fusableProducerOp ->getNumResults () +
1586+ index;
16301587 }
16311588 }
16321589 if (Operation *tiledAndFusedOp =
@@ -1640,6 +1597,11 @@ mlir::scf::tileConsumerAndFuseProducersUsingSCF(
16401597 }
16411598 }
16421599
1600+ DenseMap<Value, Value> replacements;
1601+ for (auto [origVal, resultNumber] : origValToResultNumber) {
1602+ replacements[origVal] = loops.front ()->getResult (resultNumber);
1603+ }
1604+
16431605 return scf::SCFTileAndFuseResult{fusedProducers, tiledAndFusedOps, loops,
16441606 replacements};
16451607}
0 commit comments