2424#include " mlir/IR/PatternMatch.h"
2525#include " mlir/Interfaces/DestinationStyleOpInterface.h"
2626#include " mlir/Interfaces/TilingInterface.h"
27+ #include " mlir/Rewrite/FrozenRewritePatternSet.h"
28+ #include " mlir/Transforms/GreedyPatternRewriteDriver.h"
2729#include " llvm/ADT/TypeSwitch.h"
2830#include " llvm/Support/Debug.h"
2931#include < optional>
@@ -1315,6 +1317,172 @@ FailureOr<SmallVector<Operation *>> mlir::scf::yieldReplacementForFusedProducer(
13151317 return generatedSlices;
13161318}
13171319
1320+ namespace {
1321+
1322+ // ===----------------------------------------------------------------------===//
1323+ // SliceWorklist
1324+ // ===----------------------------------------------------------------------===//
1325+
/// Per-operation bookkeeping record used by `SliceWorklist`.
///
/// `count` is the number of queue entries that currently refer to the
/// operation: it is incremented on every push and decremented as entries are
/// popped. `isValid` records whether one of those entries is still live, i.e.
/// the operation has been pushed and not removed since.
struct EntryCount {
  bool isValid{true};
  unsigned count{0};
};
1332+
1333+ // / A FIFO worklist of operations with efficient removal and set semantics.
1334+ // /
1335+ // / This class maintains a queue of operations and a mapping of operations to
1336+ // / positions in the vector, so that operations can be removed efficiently at
1337+ // / random. When an operation is removed, it is replaced with nullptr. Such
1338+ // / nullptr are skipped when pop'ing elements.
1339+ // /
1340+ // / This is similar to the worklist used by the GreedyPatternRewriteDriver,
1341+ // / except instead FIFO so that slices for fusion can be processed breadth
1342+ // / first.
1343+ class SliceWorklist {
1344+ public:
1345+ SliceWorklist () = default ;
1346+
1347+ // / Push an operation to the end of the worklist. This assumes that
1348+ // / the given operation is not already on the worklist.
1349+ void push (Operation *op);
1350+
1351+ // / Pop the an operation from the end of the worklist. Returns nullptr if
1352+ // / there are no remaining valid operations.
1353+ Operation *pop ();
1354+
1355+ // / Remove an operation from the worklist.
1356+ void remove (Operation *op);
1357+
1358+ protected:
1359+ // / The queue of operations.
1360+ std::deque<Operation *> list;
1361+
1362+ // / A mapping of operations to the number of stale copies in the queue.
1363+ DenseMap<Operation *, EntryCount> map;
1364+ };
1365+
1366+ void SliceWorklist::push (Operation *op) {
1367+ assert (op && " cannot push nullptr to worklist" );
1368+ list.push_back (op);
1369+ EntryCount newCount = map.lookup (op);
1370+ // Because operations are only pushed on creation, valid duplicates are
1371+ // never added.
1372+ assert ((!map.contains (op) || !newCount.isValid ) &&
1373+ " cannot push a duplicate operation" );
1374+ map[op] = {/* isValid=*/ true , newCount.count + 1 };
1375+ }
1376+
1377+ Operation *SliceWorklist::pop () {
1378+ // Pop the front of the queue until we hit a valid entry.
1379+ while (!list.empty ()) {
1380+ Operation *op = list.front ();
1381+ list.pop_front ();
1382+
1383+ EntryCount e = map.lookup (op);
1384+ // If the entry count is greater than 1 or there is no valid entry,
1385+ // this must be a stale entry. Decrement the map entry by one and continue.
1386+ if (e.count > 1 || !e.isValid ) {
1387+ int64_t newCount = e.count - 1 ;
1388+ if (newCount <= 0 )
1389+ map.erase (op);
1390+ else
1391+ map[op] = {e.isValid , static_cast <unsigned int >(newCount)};
1392+ continue ;
1393+ }
1394+
1395+ map.erase (op);
1396+ return op;
1397+ }
1398+ return nullptr ;
1399+ }
1400+
1401+ // Mark the operation as invalid if present. Removal from the map will
1402+ // happen later when popping from the worklist.
1403+ void SliceWorklist::remove (Operation *op) {
1404+ if (!map.contains (op))
1405+ return ;
1406+
1407+ EntryCount e = map.lookup (op);
1408+ map[op] = {/* isValid=*/ false , e.count };
1409+ }
1410+
1411+ // ===----------------------------------------------------------------------===//
1412+ // SliceTrackingListener
1413+ // ===----------------------------------------------------------------------===//
1414+
1415+ // / This class is a listener for tracking the insertion and removal of
1416+ // / `tensor.extract_slice` ops in a worklist. This can be used in a greedy
1417+ // / fusion algorithm to apply cleanup patterns in between fusion steps.
1418+ class SliceTrackingListener : public RewriterBase ::Listener {
1419+ public:
1420+ explicit SliceTrackingListener (
1421+ std::optional<FrozenRewritePatternSet> patterns);
1422+ SliceTrackingListener () = default ;
1423+
1424+ // / Adds the given list of operations to the worklist, and if present, applies
1425+ // / the list of `patterns` to the newly added operations. This only processes
1426+ // / the given operations and any newly inserted ones by the pattern set.
1427+ LogicalResult insertAndApplyPatterns (ArrayRef<Operation *> newOps);
1428+
1429+ // / Add to the new operation worklist if it is an extract_slice.
1430+ void notifyOperationInserted (Operation *op,
1431+ OpBuilder::InsertPoint previous) override ;
1432+
1433+ // / Remove the operation from the worklist.
1434+ void notifyOperationErased (Operation *op) override ;
1435+
1436+ // / Remove the operation from the worklist.
1437+ void notifyOperationReplaced (Operation *op, ValueRange replacement) override ;
1438+
1439+ // / The worklist for this transformation keeps track of the operations that
1440+ // / need to be (re)visited.
1441+ SliceWorklist worklist;
1442+
1443+ private:
1444+ // / Optional pattern set to apply when adding new operations to the worklist.
1445+ std::optional<FrozenRewritePatternSet> patterns = std::nullopt ;
1446+ };
1447+
1448+ SliceTrackingListener::SliceTrackingListener (
1449+ std::optional<FrozenRewritePatternSet> p) {
1450+ patterns = std::move (p);
1451+ }
1452+
1453+ LogicalResult
1454+ SliceTrackingListener::insertAndApplyPatterns (ArrayRef<Operation *> ops) {
1455+ for (Operation *op : ops) {
1456+ if (isa<tensor::ExtractSliceOp>(op))
1457+ worklist.push (op);
1458+ }
1459+
1460+ if (!patterns)
1461+ return success ();
1462+
1463+ GreedyRewriteConfig config;
1464+ config.listener = this ;
1465+ config.strictMode = GreedyRewriteStrictness::ExistingAndNewOps;
1466+ return applyOpPatternsAndFold (ops, patterns.value (), config);
1467+ }
1468+
1469+ void SliceTrackingListener::notifyOperationInserted (
1470+ Operation *op, OpBuilder::InsertPoint previous) {
1471+ if (!isa<tensor::ExtractSliceOp>(op))
1472+ return ;
1473+ worklist.push (op);
1474+ }
1475+
1476+ void SliceTrackingListener::notifyOperationErased (Operation *op) {
1477+ worklist.remove (op);
1478+ }
1479+
1480+ void SliceTrackingListener::notifyOperationReplaced (Operation *op,
1481+ ValueRange replacement) {
1482+ worklist.remove (op);
1483+ }
1484+ } // namespace
1485+
13181486// / Implementation of tile consumer and fuse producer greedily.
13191487FailureOr<scf::SCFTileAndFuseResult>
13201488mlir::scf::tileConsumerAndFuseProducersUsingSCF (
@@ -1370,33 +1538,33 @@ mlir::scf::tileConsumerAndFuseProducersUsingSCF(
13701538 tensor::ExtractSliceOp candidateSlice;
13711539 SCFTileAndFuseOptions::ControlFnResult controlFnResult;
13721540 };
1373- std::deque<WorklistItem> worklist;
1374- auto addCandidateSlices = [&worklist, &options,
1375- &loops](ArrayRef<Operation *> candidates) {
1376- for (auto candidate : candidates) {
1377- auto sliceOp = dyn_cast<tensor::ExtractSliceOp>(candidate);
1378- if (!sliceOp || sliceOp.use_empty ())
1379- continue ;
13801541
1381- auto [fusableProducer, destinationInitArg] =
1382- getUntiledProducerFromSliceSource (&sliceOp.getSourceMutable (), loops);
1383- if (!fusableProducer)
1384- continue ;
1385- std::optional<SCFTileAndFuseOptions::ControlFnResult> controlFnResult =
1386- options.fusionControlFn (sliceOp, fusableProducer,
1387- destinationInitArg.has_value ());
1388- if (!controlFnResult)
1389- continue ;
1390- worklist.emplace_back (WorklistItem{sliceOp, controlFnResult.value ()});
1391- }
1392- };
1542+ SliceTrackingListener sliceTracker =
1543+ SliceTrackingListener (options.cleanupPatterns );
13931544
1394- addCandidateSlices (tilingResult->generatedSlices );
1545+ if (failed (
1546+ sliceTracker.insertAndApplyPatterns (tilingResult->generatedSlices ))) {
1547+ return rewriter.notifyMatchFailure (consumer, " cleanup patterns failed" );
1548+ }
13951549 OpBuilder::InsertionGuard g (rewriter);
1396- while (!worklist.empty ()) {
1397- // Traverse the slices in BFS fashion.
1398- WorklistItem worklistItem = worklist.front ();
1399- worklist.pop_front ();
1550+ while (Operation *next = sliceTracker.worklist .pop ()) {
1551+ auto candidateSlice = dyn_cast<tensor::ExtractSliceOp>(next);
1552+ if (!candidateSlice)
1553+ continue ;
1554+
1555+ auto [fusableProducer, destinationInitArg] =
1556+ getUntiledProducerFromSliceSource (&candidateSlice.getSourceMutable (),
1557+ loops);
1558+ if (!fusableProducer)
1559+ continue ;
1560+
1561+ std::optional<SCFTileAndFuseOptions::ControlFnResult> controlFnResult =
1562+ options.fusionControlFn (candidateSlice, fusableProducer,
1563+ destinationInitArg.has_value ());
1564+ if (!controlFnResult)
1565+ continue ;
1566+
1567+ WorklistItem worklistItem = {candidateSlice, controlFnResult.value ()};
14001568
14011569 // The operands of the fused producer might themselves be slices of
14021570 // values produced by operations that implement the `TilingInterface`.
@@ -1407,6 +1575,8 @@ mlir::scf::tileConsumerAndFuseProducersUsingSCF(
14071575 if (!fusedResult)
14081576 continue ;
14091577
1578+ SmallVector<Operation *> worklistCandidates = fusedResult->generatedSlices ;
1579+
14101580 if (worklistItem.controlFnResult .yieldProducerReplacement ) {
14111581 // Reconstruct and yield all opResult of fusableProducerOp by default. The
14121582 // caller can specify which one to yield by designating optional argument
@@ -1421,20 +1591,23 @@ mlir::scf::tileConsumerAndFuseProducersUsingSCF(
14211591 fusableProducerOp, " failed to replacement value for this "
14221592 " operation from within the tiled loop" );
14231593 }
1424- addCandidateSlices (newSlices.value ());
1594+ worklistCandidates. append (newSlices.value ());
14251595 for (auto [index, result] :
14261596 llvm::enumerate (fusableProducerOp->getResults ())) {
14271597 origValToResultNumber[result] = loops.front ()->getNumResults () -
14281598 fusableProducerOp->getNumResults () +
14291599 index;
14301600 }
14311601 }
1432- addCandidateSlices (fusedResult->generatedSlices );
14331602 if (Operation *tiledAndFusedOp =
14341603 fusedResult->tiledAndFusedProducer .getDefiningOp ()) {
14351604 fusedProducers.insert (fusedResult->origProducer .getDefiningOp ());
14361605 tiledAndFusedOps.insert (tiledAndFusedOp);
14371606 }
1607+
1608+ if (failed (sliceTracker.insertAndApplyPatterns (worklistCandidates))) {
1609+ return rewriter.notifyMatchFailure (consumer, " cleanup patterns failed" );
1610+ }
14381611 }
14391612
14401613 DenseMap<Value, Value> replacements;
0 commit comments