-
Notifications
You must be signed in to change notification settings - Fork 14.9k
[MLIR] Refactor the walkAndApplyPatterns driver to remove the recursion #154037
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
@llvm/pr-subscribers-mlir-core @llvm/pr-subscribers-mlir Author: Mehdi Amini (joker-eph) ChangesThis is in preparation of a follow-up change to stop traversing unreachable blocks. Full diff: https://github.com/llvm/llvm-project/pull/154037.diff 2 Files Affected:
diff --git a/mlir/lib/Transforms/Utils/WalkPatternRewriteDriver.cpp b/mlir/lib/Transforms/Utils/WalkPatternRewriteDriver.cpp
index ee5c642c943c4..52f8ea5472883 100644
--- a/mlir/lib/Transforms/Utils/WalkPatternRewriteDriver.cpp
+++ b/mlir/lib/Transforms/Utils/WalkPatternRewriteDriver.cpp
@@ -13,12 +13,14 @@
#include "mlir/Transforms/WalkPatternRewriteDriver.h"
#include "mlir/IR/MLIRContext.h"
+#include "mlir/IR/Operation.h"
#include "mlir/IR/OperationSupport.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/Verifier.h"
#include "mlir/IR/Visitors.h"
#include "mlir/Rewrite/PatternApplicator.h"
-#include "llvm/Support/Debug.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/Support/DebugLog.h"
#include "llvm/Support/ErrorHandling.h"
#define DEBUG_TYPE "walk-rewriter"
@@ -88,20 +90,83 @@ void walkAndApplyPatterns(Operation *op,
PatternApplicator applicator(patterns);
applicator.applyDefaultCostModel();
+ // Cursor to track where we're at in the traversal.
+ struct Cursor {
+ Cursor(Region *region) : region(region) {
+ regionIt = region->begin();
+ if (regionIt != region->end())
+ blockIt = regionIt->begin();
+ }
+ void next() {
+ assert(regionIt != region->end());
+ hasVisitedRegions = false;
+ if (blockIt == regionIt->end()) {
+ regionIt++;
+ return;
+ }
+ blockIt++;
+ if (blockIt != regionIt->end()) {
+ LDBG() << "Incrementing block iterator, next op: "
+ << OpWithFlags(&*blockIt, OpPrintingFlags().skipRegions());
+ }
+ }
+ Region *region;
+ Region::iterator regionIt;
+ Block::iterator blockIt;
+ bool hasVisitedRegions = false;
+ };
+ SmallVector<Cursor> worklist;
+
+ LDBG() << "Starting walk-based pattern rewrite driver";
ctx->executeAction<WalkAndApplyPatternsAction>(
[&] {
for (Region ®ion : op->getRegions()) {
- region.walk([&](Operation *visitedOp) {
- LLVM_DEBUG(llvm::dbgs() << "Visiting op: "; visitedOp->print(
- llvm::dbgs(), OpPrintingFlags().skipRegions());
- llvm::dbgs() << "\n";);
+ assert(worklist.empty());
+ if (region.empty())
+ continue;
+
+ // Prime the worklist with the entry block of this region.
+ worklist.push_back({®ion});
+ while (!worklist.empty()) {
+ Cursor &cursor = worklist.back();
+ if (cursor.regionIt == cursor.region->end()) {
+ // We're done with this region.
+ worklist.pop_back();
+ continue;
+ }
+ if (cursor.blockIt == cursor.regionIt->end()) {
+ // We're done with this block.
+ cursor.regionIt++;
+ if (cursor.regionIt != cursor.region->end())
+ cursor.blockIt = cursor.regionIt->begin();
+ continue;
+ }
+ Operation *op = &*cursor.blockIt;
+ if (!cursor.hasVisitedRegions) {
+ cursor.hasVisitedRegions = true;
+ for (Region &nestedRegion : llvm::reverse(op->getRegions())) {
+ if (nestedRegion.empty())
+ continue;
+ worklist.push_back({&nestedRegion});
+ }
+ }
+ // If we're not at the back of the worklist, we're visiting a nested
+ // region first. We'll come back to this op later.
+ if (&cursor != &worklist.back())
+ continue;
+
+ // Premptively increment the cursor, in case the current op
+ // would be erased.
+ cursor.next();
+
+ LDBG() << "Visiting op: "
+ << OpWithFlags(op, OpPrintingFlags().skipRegions());
#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
- erasedListener.visitedOp = visitedOp;
+ erasedListener.visitedOp = op;
#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
- if (succeeded(applicator.matchAndRewrite(visitedOp, rewriter))) {
- LLVM_DEBUG(llvm::dbgs() << "\tOp matched and rewritten\n";);
- }
- });
+ if (succeeded(applicator.matchAndRewrite(op, rewriter)))
+ LDBG() << "\tOp matched and rewritten";
+ }
}
},
{op});
diff --git a/mlir/test/IR/test-walk-pattern-rewrite-driver.mlir b/mlir/test/IR/test-walk-pattern-rewrite-driver.mlir
index 02f7e60671c9b..c75c478ec3734 100644
--- a/mlir/test/IR/test-walk-pattern-rewrite-driver.mlir
+++ b/mlir/test/IR/test-walk-pattern-rewrite-driver.mlir
@@ -40,12 +40,12 @@ func.func @move_before(%cond : i1) {
}
// Check that the driver handles rewriter.moveAfter. In this case, we expect
-// the moved op to be visited only once since walk uses `make_early_inc_range`.
+// the moved op to be visited twice.
// CHECK-LABEL: func.func @move_after(
// CHECK: scf.if
// CHECK: }
// CHECK: "test.move_after_parent_op"
-// CHECK: "test.any_attr_of_i32_str"() <{attr = 1 : i32}> : () -> ()
+// CHECK: "test.any_attr_of_i32_str"() <{attr = 2 : i32}> : () -> ()
// CHECK: return
func.func @move_after(%cond : i1) {
scf.if %cond {
|
f98c1aa
to
c074f74
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Pull Request Overview
This PR refactors the walkAndApplyPatterns
driver to replace recursive traversal with an iterative approach using a worklist. This change is in preparation for future modifications to stop traversing unreachable blocks and addresses a subtle difference in when the iterator is pre-incremented relative to processing nested regions.
Key changes:
- Replaces the recursive
region.walk()
with an iterative post-order traversal using a customRegionReachableOpIterator
- Changes the timing of iterator pre-increment to occur after processing nested regions rather than before
- Updates test expectations to reflect that moved operations are now visited twice instead of once
Reviewed Changes
Copilot reviewed 2 out of 2 changed files in this pull request and generated 3 comments.
File | Description |
---|---|
mlir/lib/Transforms/Utils/WalkPatternRewriteDriver.cpp | Implements new iterative traversal logic with custom iterator and worklist |
mlir/test/IR/test-walk-pattern-rewrite-driver.mlir | Updates test expectations for the new traversal behavior |
c074f74
to
fa92472
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is not NFC because of a subtlety of the early_inc.
...
This sounds perfectly fine IMO
…on (NFC) This is in preparation of a follow-up change to stop traversing unreachable blocks.
fa92472
to
a3d9296
Compare
…on (llvm#154037) This is in preparation of a follow-up change to stop traversing unreachable blocks. This is not NFC because of a subtlety of the early_inc. On a test case like: ``` scf.if %cond { "test.move_after_parent_op"() ({ "test.any_attr_of_i32_str"() {attr = 0 : i32} : () -> () }) : () -> () } ``` We recursively traverse the nested regions, and process an op when the region is done (post-order). We need to pre-increment the iterator before processing an operation in case it gets deleted. However we can do this before or after processing the nested region. This implementation does the latter.
…on (#154037) This is in preparation of a follow-up change to stop traversing unreachable blocks. This is not NFC because of a subtlety of the early_inc. On a test case like: ``` scf.if %cond { "test.move_after_parent_op"() ({ "test.any_attr_of_i32_str"() {attr = 0 : i32} : () -> () }) : () -> () } ``` We recursively traverse the nested regions, and process an op when the region is done (post-order). We need to pre-increment the iterator before processing an operation in case it gets deleted. However we can do this before or after processing the nested region. This implementation does the latter.
This is in preparation of a follow-up change to stop traversing unreachable blocks.
This is not NFC because of a subtlety of the early_inc. On a test case like:
We recursively traverse the nested regions, and process an op when the region is done (post-order).
We need to pre-increment the iterator before processing an operation in case it gets deleted. However
we can do this before or after processing the nested region. This implementation does the latter.