Skip to content

Commit 8fd70f7

Browse files
committed
[MLIR] Stop visiting unreachable blocks in the walkAndApplyPatterns driver (llvm#154038)
This is similar to the fix to the greedy driver in llvm#153957 ; except that instead of removing unreachable code, we just ignore it. Operations like: ``` %add = arith.addi %add, %add : i64 ``` are legal in unreachable code. Unfortunately many patterns would be unsafe to apply on such IR and can lead to crashes or infinite loops.
1 parent b6cd6f2 commit 8fd70f7

File tree

3 files changed

+64
-7
lines changed

3 files changed

+64
-7
lines changed

mlir/include/mlir/Transforms/WalkPatternRewriteDriver.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,8 @@ namespace mlir {
2727
/// This is intended as the simplest and most lightweight pattern rewriter in
2828
/// cases when a simple walk gets the job done.
2929
///
30+
/// The driver will skip unreachable blocks.
31+
///
3032
/// Note: Does not apply patterns to the given operation itself.
3133
void walkAndApplyPatterns(Operation *op,
3234
const FrozenRewritePatternSet &patterns,

mlir/lib/Transforms/Utils/WalkPatternRewriteDriver.cpp

Lines changed: 42 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -20,13 +20,33 @@
2020
#include "mlir/IR/Visitors.h"
2121
#include "mlir/Rewrite/PatternApplicator.h"
2222
#include "llvm/ADT/STLExtras.h"
23-
#include "llvm/Support/DebugLog.h"
23+
#include "llvm/Support/Debug.h"
2424
#include "llvm/Support/ErrorHandling.h"
2525

2626
#define DEBUG_TYPE "walk-rewriter"
2727

2828
namespace mlir {
2929

30+
// Find all reachable blocks in the region and add them to the visitedBlocks
31+
// set.
32+
static void findReachableBlocks(Region &region,
33+
DenseSet<Block *> &reachableBlocks) {
34+
Block *entryBlock = &region.front();
35+
reachableBlocks.insert(entryBlock);
36+
// Traverse the CFG and add all reachable blocks to the blockList.
37+
SmallVector<Block *> worklist({entryBlock});
38+
while (!worklist.empty()) {
39+
Block *block = worklist.pop_back_val();
40+
Operation *terminator = &block->back();
41+
for (Block *successor : terminator->getSuccessors()) {
42+
if (reachableBlocks.contains(successor))
43+
continue;
44+
worklist.push_back(successor);
45+
reachableBlocks.insert(successor);
46+
}
47+
}
48+
}
49+
3050
namespace {
3151
struct WalkAndApplyPatternsAction final
3252
: tracing::ActionImpl<WalkAndApplyPatternsAction> {
@@ -98,21 +118,30 @@ void walkAndApplyPatterns(Operation *op,
98118
regionIt = region->begin();
99119
if (regionIt != region->end())
100120
blockIt = regionIt->begin();
121+
if (!llvm::hasSingleElement(*region))
122+
findReachableBlocks(*region, reachableBlocks);
101123
}
102124
// Advance the iterator to the next reachable operation.
103125
void advance() {
104126
assert(regionIt != region->end());
105127
hasVisitedRegions = false;
106128
if (blockIt == regionIt->end()) {
107129
++regionIt;
130+
while (regionIt != region->end() &&
131+
!reachableBlocks.contains(&*regionIt))
132+
++regionIt;
108133
if (regionIt != region->end())
109134
blockIt = regionIt->begin();
110135
return;
111136
}
112137
++blockIt;
113138
if (blockIt != regionIt->end()) {
114-
LDBG() << "Incrementing block iterator, next op: "
115-
<< OpWithFlags(&*blockIt, OpPrintingFlags().skipRegions());
139+
LLVM_DEBUG({
140+
llvm::dbgs() << "Incrementing block iterator, next op: "
141+
<< OpWithFlags(&*blockIt,
142+
OpPrintingFlags().skipRegions())
143+
<< "\n";
144+
});
116145
}
117146
}
118147
// The region we're iterating over.
@@ -121,14 +150,17 @@ void walkAndApplyPatterns(Operation *op,
121150
Region::iterator regionIt;
122151
// The Operation currently being iterated over.
123152
Block::iterator blockIt;
153+
// The set of blocks that are reachable in the current region.
154+
DenseSet<Block *> reachableBlocks;
124155
// Whether we've visited the nested regions of the current op already.
125156
bool hasVisitedRegions = false;
126157
};
127158

128159
// Worklist of regions to visit to drive the post-order traversal.
129160
SmallVector<RegionReachableOpIterator> worklist;
130161

131-
LDBG() << "Starting walk-based pattern rewrite driver";
162+
LLVM_DEBUG(
163+
{ llvm::dbgs() << "Starting walk-based pattern rewrite driver\n"; });
132164
ctx->executeAction<WalkAndApplyPatternsAction>(
133165
[&] {
134166
// Perform a post-order traversal of the regions, visiting each
@@ -173,13 +205,16 @@ void walkAndApplyPatterns(Operation *op,
173205
// would be erased.
174206
it.advance();
175207

176-
LDBG() << "Visiting op: "
177-
<< OpWithFlags(op, OpPrintingFlags().skipRegions());
208+
LLVM_DEBUG({
209+
llvm::dbgs() << "Visiting op: "
210+
<< OpWithFlags(op, OpPrintingFlags().skipRegions())
211+
<< "\n";
212+
});
178213
#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
179214
erasedListener.visitedOp = op;
180215
#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
181216
if (succeeded(applicator.matchAndRewrite(op, rewriter)))
182-
LDBG() << "\tOp matched and rewritten";
217+
LLVM_DEBUG({ llvm::dbgs() << "\tOp matched and rewritten\n"; });
183218
}
184219
}
185220
},

mlir/test/IR/test-walk-pattern-rewrite-driver.mlir

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -119,3 +119,23 @@ func.func @erase_nested_block() -> i32 {
119119
}): () -> (i32)
120120
return %a : i32
121121
}
122+
123+
124+
// CHECK-LABEL: func.func @unreachable_replace_with_new_op
125+
// CHECK: "test.new_op"
126+
// CHECK: "test.replace_with_new_op"
127+
// CHECK-SAME: unreachable
128+
// CHECK: "test.new_op"
129+
func.func @unreachable_replace_with_new_op() {
130+
"test.br"()[^bb1] : () -> ()
131+
^bb1:
132+
%a = "test.replace_with_new_op"() : () -> (i32)
133+
"test.br"()[^end] : () -> () // Test jumping over the unreachable block is visited as well.
134+
^unreachable:
135+
%b = "test.replace_with_new_op"() {test.unreachable} : () -> (i32)
136+
return
137+
^end:
138+
%c = "test.replace_with_new_op"() : () -> (i32)
139+
return
140+
}
141+

0 commit comments

Comments
 (0)