Skip to content

Commit 191e7eb

Browse files
authored
[MLIR] Stop visiting unreachable blocks in the walkAndApplyPatterns driver (#154038)
This is similar to the fix to the greedy driver in #153957 ; except that instead of removing unreachable code, we just ignore it. Operations like: ``` %add = arith.addi %add, %add : i64 ``` are legal in unreachable code. Unfortunately many patterns would be unsafe to apply on such IR and can lead to crashes or infinite loops.
1 parent 624b724 commit 191e7eb

File tree

3 files changed

+49
-0
lines changed

3 files changed

+49
-0
lines changed

mlir/include/mlir/Transforms/WalkPatternRewriteDriver.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,8 @@ namespace mlir {
2727
/// This is intended as the simplest and most lightweight pattern rewriter in
2828
/// cases when a simple walk gets the job done.
2929
///
30+
/// The driver will skip unreachable blocks.
31+
///
3032
/// Note: Does not apply patterns to the given operation itself.
3133
void walkAndApplyPatterns(Operation *op,
3234
const FrozenRewritePatternSet &patterns,

mlir/lib/Transforms/Utils/WalkPatternRewriteDriver.cpp

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,26 @@
2727

2828
namespace mlir {
2929

30+
// Find all reachable blocks in the region and add them to the visitedBlocks
31+
// set.
32+
static void findReachableBlocks(Region &region,
33+
DenseSet<Block *> &reachableBlocks) {
34+
Block *entryBlock = &region.front();
35+
reachableBlocks.insert(entryBlock);
36+
// Traverse the CFG and add all reachable blocks to the blockList.
37+
SmallVector<Block *> worklist({entryBlock});
38+
while (!worklist.empty()) {
39+
Block *block = worklist.pop_back_val();
40+
Operation *terminator = &block->back();
41+
for (Block *successor : terminator->getSuccessors()) {
42+
if (reachableBlocks.contains(successor))
43+
continue;
44+
worklist.push_back(successor);
45+
reachableBlocks.insert(successor);
46+
}
47+
}
48+
}
49+
3050
namespace {
3151
struct WalkAndApplyPatternsAction final
3252
: tracing::ActionImpl<WalkAndApplyPatternsAction> {
@@ -98,13 +118,18 @@ void walkAndApplyPatterns(Operation *op,
98118
regionIt = region->begin();
99119
if (regionIt != region->end())
100120
blockIt = regionIt->begin();
121+
if (!llvm::hasSingleElement(*region))
122+
findReachableBlocks(*region, reachableBlocks);
101123
}
102124
// Advance the iterator to the next reachable operation.
103125
void advance() {
104126
assert(regionIt != region->end());
105127
hasVisitedRegions = false;
106128
if (blockIt == regionIt->end()) {
107129
++regionIt;
130+
while (regionIt != region->end() &&
131+
!reachableBlocks.contains(&*regionIt))
132+
++regionIt;
108133
if (regionIt != region->end())
109134
blockIt = regionIt->begin();
110135
return;
@@ -121,6 +146,8 @@ void walkAndApplyPatterns(Operation *op,
121146
Region::iterator regionIt;
122147
// The Operation currently being iterated over.
123148
Block::iterator blockIt;
149+
// The set of blocks that are reachable in the current region.
150+
DenseSet<Block *> reachableBlocks;
124151
// Whether we've visited the nested regions of the current op already.
125152
bool hasVisitedRegions = false;
126153
};

mlir/test/IR/test-walk-pattern-rewrite-driver.mlir

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -119,3 +119,23 @@ func.func @erase_nested_block() -> i32 {
119119
}): () -> (i32)
120120
return %a : i32
121121
}
122+
123+
124+
// CHECK-LABEL: func.func @unreachable_replace_with_new_op
125+
// CHECK: "test.new_op"
126+
// CHECK: "test.replace_with_new_op"
127+
// CHECK-SAME: unreachable
128+
// CHECK: "test.new_op"
129+
func.func @unreachable_replace_with_new_op() {
130+
"test.br"()[^bb1] : () -> ()
131+
^bb1:
132+
%a = "test.replace_with_new_op"() : () -> (i32)
133+
"test.br"()[^end] : () -> () // Test jumping over the unreachable block is visited as well.
134+
^unreachable:
135+
%b = "test.replace_with_new_op"() {test.unreachable} : () -> (i32)
136+
return
137+
^end:
138+
%c = "test.replace_with_new_op"() : () -> (i32)
139+
return
140+
}
141+

0 commit comments

Comments
 (0)