-
Notifications
You must be signed in to change notification settings - Fork 14.9k
[MLIR] Stop visiting unreachable blocks in the walkAndApplyPatterns driver #154038
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-core Author: Mehdi Amini (joker-eph) ChangesThis is similar to the fix to the greedy driver in #153957 ; except that instead of removing unreachable code, we just ignore it. Operations like:
are legal in unreachable code. Full diff: https://github.com/llvm/llvm-project/pull/154038.diff 3 Files Affected:
diff --git a/mlir/include/mlir/Transforms/WalkPatternRewriteDriver.h b/mlir/include/mlir/Transforms/WalkPatternRewriteDriver.h
index 6d62ae3dd43dc..7d5c1d5cebb26 100644
--- a/mlir/include/mlir/Transforms/WalkPatternRewriteDriver.h
+++ b/mlir/include/mlir/Transforms/WalkPatternRewriteDriver.h
@@ -27,6 +27,8 @@ namespace mlir {
/// This is intended as the simplest and most lightweight pattern rewriter in
/// cases when a simple walk gets the job done.
///
+/// The driver will skip unreachable blocks.
+///
/// Note: Does not apply patterns to the given operation itself.
void walkAndApplyPatterns(Operation *op,
const FrozenRewritePatternSet &patterns,
diff --git a/mlir/lib/Transforms/Utils/WalkPatternRewriteDriver.cpp b/mlir/lib/Transforms/Utils/WalkPatternRewriteDriver.cpp
index 52f8ea5472883..8f26a294f6d9b 100644
--- a/mlir/lib/Transforms/Utils/WalkPatternRewriteDriver.cpp
+++ b/mlir/lib/Transforms/Utils/WalkPatternRewriteDriver.cpp
@@ -27,6 +27,26 @@
namespace mlir {
+// Find all reachable blocks in the region and add them to the visitedBlocks
+// set.
+static void findReachableBlocks(Region ®ion,
+ DenseSet<Block *> &reachableBlocks) {
+ Block *entryBlock = ®ion.front();
+ reachableBlocks.insert(entryBlock);
+ // Traverse the CFG and add all reachable blocks to the blockList.
+ SmallVector<Block *> worklist({entryBlock});
+ Block *block = worklist.pop_back_val();
+ while (!worklist.empty()) {
+ Operation *terminator = &block->back();
+ for (Block *successor : terminator->getSuccessors()) {
+ if (reachableBlocks.contains(successor))
+ continue;
+ worklist.push_back(successor);
+ reachableBlocks.insert(successor);
+ }
+ }
+}
+
namespace {
struct WalkAndApplyPatternsAction final
: tracing::ActionImpl<WalkAndApplyPatternsAction> {
@@ -90,18 +110,28 @@ void walkAndApplyPatterns(Operation *op,
PatternApplicator applicator(patterns);
applicator.applyDefaultCostModel();
- // Cursor to track where we're at in the traversal.
- struct Cursor {
- Cursor(Region *region) : region(region) {
+ // Iterator on all reachable operations in the region.
+ // Also keep track if we visited the nested regions of the current op
+ // already to drive the post-order traversal.
+ struct RegionReachableOpIterator {
+ RegionReachableOpIterator(Region *region) : region(region) {
regionIt = region->begin();
if (regionIt != region->end())
blockIt = regionIt->begin();
+ if (!llvm::hasSingleElement(*region))
+ findReachableBlocks(*region, reachableBlocks);
}
- void next() {
+ // Advance the iterator to the next reachable operation.
+ void advance() {
assert(regionIt != region->end());
hasVisitedRegions = false;
if (blockIt == regionIt->end()) {
regionIt++;
+ while (regionIt != region->end() &&
+ !reachableBlocks.contains(&*regionIt))
+ regionIt++;
+ if (regionIt != region->end())
+ blockIt = regionIt->begin();
return;
}
blockIt++;
@@ -110,14 +140,23 @@ void walkAndApplyPatterns(Operation *op,
<< OpWithFlags(&*blockIt, OpPrintingFlags().skipRegions());
}
}
+
+ // The region we're iterating over.
Region *region;
+ // The Block currently being iterated over.
Region::iterator regionIt;
+ // The Operation currently being iterated over.
Block::iterator blockIt;
+ // The set of blocks that are reachable in the current region.
+ DenseSet<Block *> reachableBlocks;
+ // Whether we've visited the nested regions of the current op already.
bool hasVisitedRegions = false;
};
- SmallVector<Cursor> worklist;
+ SmallVector<RegionReachableOpIterator> worklist;
LDBG() << "Starting walk-based pattern rewrite driver";
+ // Perform a post-order traversal of the region, visiting each reachable
+ // operation.
ctx->executeAction<WalkAndApplyPatternsAction>(
[&] {
for (Region ®ion : op->getRegions()) {
@@ -128,36 +167,37 @@ void walkAndApplyPatterns(Operation *op,
// Prime the worklist with the entry block of this region.
worklist.push_back({®ion});
while (!worklist.empty()) {
- Cursor &cursor = worklist.back();
- if (cursor.regionIt == cursor.region->end()) {
+ RegionReachableOpIterator &it = worklist.back();
+ if (it.regionIt == it.region->end()) {
// We're done with this region.
worklist.pop_back();
continue;
}
- if (cursor.blockIt == cursor.regionIt->end()) {
+ if (it.blockIt == it.regionIt->end()) {
// We're done with this block.
- cursor.regionIt++;
- if (cursor.regionIt != cursor.region->end())
- cursor.blockIt = cursor.regionIt->begin();
+ it.advance();
continue;
}
- Operation *op = &*cursor.blockIt;
- if (!cursor.hasVisitedRegions) {
- cursor.hasVisitedRegions = true;
+ Operation *op = &*it.blockIt;
+ // If we haven't visited the nested regions of this op yet,
+ // enqueue them.
+ if (!it.hasVisitedRegions) {
+ it.hasVisitedRegions = true;
for (Region &nestedRegion : llvm::reverse(op->getRegions())) {
if (nestedRegion.empty())
continue;
worklist.push_back({&nestedRegion});
}
}
- // If we're not at the back of the worklist, we're visiting a nested
- // region first. We'll come back to this op later.
- if (&cursor != &worklist.back())
+ // If we're not at the back of the worklist, we've enqueued some
+ // nested region for processing. We'll come back to this op later
+ // (post-order)
+ if (&it != &worklist.back())
continue;
// Premptively increment the cursor, in case the current op
// would be erased.
- cursor.next();
+ it.advance();
LDBG() << "Visiting op: "
<< OpWithFlags(op, OpPrintingFlags().skipRegions());
diff --git a/mlir/test/IR/test-walk-pattern-rewrite-driver.mlir b/mlir/test/IR/test-walk-pattern-rewrite-driver.mlir
index c75c478ec3734..1acff6fdf029e 100644
--- a/mlir/test/IR/test-walk-pattern-rewrite-driver.mlir
+++ b/mlir/test/IR/test-walk-pattern-rewrite-driver.mlir
@@ -119,3 +119,13 @@ func.func @erase_nested_block() -> i32 {
}): () -> (i32)
return %a : i32
}
+
+
+// CHECK-LABEL: func.func @unreachable_replace_with_new_op
+// CHECK: "test.replace_with_new_op"
+func.func @unreachable_replace_with_new_op() {
+ return
+^unreachable:
+ %a = "test.replace_with_new_op"() : () -> (i32)
+ return
+}
|
f98c1aa
to
c074f74
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Pull Request Overview
This PR modifies the walkAndApplyPatterns
driver in MLIR to skip visiting unreachable blocks during pattern matching, preventing crashes and infinite loops that can occur when patterns are applied to illegal IR in unreachable code. The change implements a reachability analysis and custom iterator to ensure only reachable operations are visited during the walk.
- Adds reachability analysis to identify reachable blocks before pattern application
- Replaces the simple
region.walk()
with a custom iterator that skips unreachable blocks - Updates test expectations to reflect the new behavior
Reviewed Changes
Copilot reviewed 3 out of 3 changed files in this pull request and generated 3 comments.
File | Description |
---|---|
mlir/lib/Transforms/Utils/WalkPatternRewriteDriver.cpp | Implements the core reachability analysis and custom iterator logic |
mlir/include/mlir/Transforms/WalkPatternRewriteDriver.h | Updates documentation to reflect the new unreachable block skipping behavior |
mlir/test/IR/test-walk-pattern-rewrite-driver.mlir | Updates test expectations and adds new test case for unreachable blocks |
c074f74
to
fa92472
Compare
e5070af
to
7260f92
Compare
fa92472
to
a3d9296
Compare
7260f92
to
a33b1af
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks, LGTM!
…river This is similar to the fix to the greedy driver in #153957 ; except that instead of removing unreachable code, we just ignore it. Operations like: %add = arith.addi %add, %add : i64 are legal in unreachable code. Unfortunately many patterns would be unsafe to apply on such IR and can lead to crashes or infinite loops.
a33b1af
to
7403b6d
Compare
…river (llvm#154038) This is similar to the fix to the greedy driver in llvm#153957 ; except that instead of removing unreachable code, we just ignore it. Operations like: ``` %add = arith.addi %add, %add : i64 ``` are legal in unreachable code. Unfortunately many patterns would be unsafe to apply on such IR and can lead to crashes or infinite loops.
…river (#154038) This is similar to the fix to the greedy driver in #153957 ; except that instead of removing unreachable code, we just ignore it. Operations like: ``` %add = arith.addi %add, %add : i64 ``` are legal in unreachable code. Unfortunately many patterns would be unsafe to apply on such IR and can lead to crashes or infinite loops.
This is similar to the fix to the greedy driver in #153957 ; except that instead of removing unreachable code, we just ignore it.
Operations like:
are legal in unreachable code.
Unfortunately many patterns would be unsafe to apply on such IR and can lead to crashes or infinite loops.