2020#include " mlir/IR/Visitors.h"
2121#include " mlir/Rewrite/PatternApplicator.h"
2222#include " llvm/ADT/STLExtras.h"
23- #include " llvm/Support/DebugLog .h"
23+ #include " llvm/Support/Debug .h"
2424#include " llvm/Support/ErrorHandling.h"
2525
2626#define DEBUG_TYPE " walk-rewriter"
2727
2828namespace mlir {
2929
30+ // Find all reachable blocks in the region and add them to the visitedBlocks
31+ // set.
32+ static void findReachableBlocks (Region ®ion,
33+ DenseSet<Block *> &reachableBlocks) {
34+ Block *entryBlock = ®ion.front ();
35+ reachableBlocks.insert (entryBlock);
36+ // Traverse the CFG and add all reachable blocks to the blockList.
37+ SmallVector<Block *> worklist ({entryBlock});
38+ while (!worklist.empty ()) {
39+ Block *block = worklist.pop_back_val ();
40+ Operation *terminator = &block->back ();
41+ for (Block *successor : terminator->getSuccessors ()) {
42+ if (reachableBlocks.contains (successor))
43+ continue ;
44+ worklist.push_back (successor);
45+ reachableBlocks.insert (successor);
46+ }
47+ }
48+ }
49+
3050namespace {
3151struct WalkAndApplyPatternsAction final
3252 : tracing::ActionImpl<WalkAndApplyPatternsAction> {
@@ -98,21 +118,30 @@ void walkAndApplyPatterns(Operation *op,
98118 regionIt = region->begin ();
99119 if (regionIt != region->end ())
100120 blockIt = regionIt->begin ();
121+ if (!llvm::hasSingleElement (*region))
122+ findReachableBlocks (*region, reachableBlocks);
101123 }
102124 // Advance the iterator to the next reachable operation.
103125 void advance () {
104126 assert (regionIt != region->end ());
105127 hasVisitedRegions = false ;
106128 if (blockIt == regionIt->end ()) {
107129 ++regionIt;
130+ while (regionIt != region->end () &&
131+ !reachableBlocks.contains (&*regionIt))
132+ ++regionIt;
108133 if (regionIt != region->end ())
109134 blockIt = regionIt->begin ();
110135 return ;
111136 }
112137 ++blockIt;
113138 if (blockIt != regionIt->end ()) {
114- LDBG () << " Incrementing block iterator, next op: "
115- << OpWithFlags (&*blockIt, OpPrintingFlags ().skipRegions ());
139+ LLVM_DEBUG ({
140+ llvm::dbgs () << " Incrementing block iterator, next op: "
141+ << OpWithFlags (&*blockIt,
142+ OpPrintingFlags ().skipRegions ())
143+ << " \n " ;
144+ });
116145 }
117146 }
118147 // The region we're iterating over.
@@ -121,14 +150,17 @@ void walkAndApplyPatterns(Operation *op,
121150 Region::iterator regionIt;
122151 // The Operation currently being iterated over.
123152 Block::iterator blockIt;
153+ // The set of blocks that are reachable in the current region.
154+ DenseSet<Block *> reachableBlocks;
124155 // Whether we've visited the nested regions of the current op already.
125156 bool hasVisitedRegions = false ;
126157 };
127158
128159 // Worklist of regions to visit to drive the post-order traversal.
129160 SmallVector<RegionReachableOpIterator> worklist;
130161
131- LDBG () << " Starting walk-based pattern rewrite driver" ;
162+ LLVM_DEBUG (
163+ { llvm::dbgs () << " Starting walk-based pattern rewrite driver\n " ; });
132164 ctx->executeAction <WalkAndApplyPatternsAction>(
133165 [&] {
134166 // Perform a post-order traversal of the regions, visiting each
@@ -173,13 +205,16 @@ void walkAndApplyPatterns(Operation *op,
173205 // would be erased.
174206 it.advance ();
175207
176- LDBG () << " Visiting op: "
177- << OpWithFlags (op, OpPrintingFlags ().skipRegions ());
208+ LLVM_DEBUG ({
209+ llvm::dbgs () << " Visiting op: "
210+ << OpWithFlags (op, OpPrintingFlags ().skipRegions ())
211+ << " \n " ;
212+ });
178213#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
179214 erasedListener.visitedOp = op;
180215#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
181216 if (succeeded (applicator.matchAndRewrite (op, rewriter)))
182- LDBG ( ) << " \t Op matched and rewritten" ;
217+ LLVM_DEBUG ({ llvm::dbgs ( ) << " \t Op matched and rewritten\n " ; }) ;
183218 }
184219 }
185220 },
0 commit comments