Skip to content

Conversation

joker-eph
Copy link
Collaborator

@joker-eph joker-eph commented Aug 17, 2025

This is in preparation of a follow-up change to stop traversing unreachable blocks.

This is not NFC because of a subtlety of the early_inc. On a test case like:

  scf.if %cond {
    "test.move_after_parent_op"() ({
      "test.any_attr_of_i32_str"() {attr = 0 : i32} : () -> ()
    }) : () -> ()
  }

We recursively traverse the nested regions, and process an op when the region is done (post-order).
We need to pre-increment the iterator before processing an operation in case it gets deleted. However
we can do this before or after processing the nested region. This implementation does the latter.

@joker-eph joker-eph requested a review from kuhar August 17, 2025 21:28
@llvmbot llvmbot added mlir:core MLIR Core Infrastructure mlir labels Aug 17, 2025
@llvmbot
Copy link
Member

llvmbot commented Aug 17, 2025

@llvm/pr-subscribers-mlir-core

@llvm/pr-subscribers-mlir

Author: Mehdi Amini (joker-eph)

Changes

This is in preparation of a follow-up change to stop traversing unreachable blocks.


Full diff: https://github.com/llvm/llvm-project/pull/154037.diff

2 Files Affected:

  • (modified) mlir/lib/Transforms/Utils/WalkPatternRewriteDriver.cpp (+75-10)
  • (modified) mlir/test/IR/test-walk-pattern-rewrite-driver.mlir (+2-2)
diff --git a/mlir/lib/Transforms/Utils/WalkPatternRewriteDriver.cpp b/mlir/lib/Transforms/Utils/WalkPatternRewriteDriver.cpp
index ee5c642c943c4..52f8ea5472883 100644
--- a/mlir/lib/Transforms/Utils/WalkPatternRewriteDriver.cpp
+++ b/mlir/lib/Transforms/Utils/WalkPatternRewriteDriver.cpp
@@ -13,12 +13,14 @@
 #include "mlir/Transforms/WalkPatternRewriteDriver.h"
 
 #include "mlir/IR/MLIRContext.h"
+#include "mlir/IR/Operation.h"
 #include "mlir/IR/OperationSupport.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/IR/Verifier.h"
 #include "mlir/IR/Visitors.h"
 #include "mlir/Rewrite/PatternApplicator.h"
-#include "llvm/Support/Debug.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/Support/DebugLog.h"
 #include "llvm/Support/ErrorHandling.h"
 
 #define DEBUG_TYPE "walk-rewriter"
@@ -88,20 +90,83 @@ void walkAndApplyPatterns(Operation *op,
   PatternApplicator applicator(patterns);
   applicator.applyDefaultCostModel();
 
+  // Cursor to track where we're at in the traversal.
+  struct Cursor {
+    Cursor(Region *region) : region(region) {
+      regionIt = region->begin();
+      if (regionIt != region->end())
+        blockIt = regionIt->begin();
+    }
+    void next() {
+      assert(regionIt != region->end());
+      hasVisitedRegions = false;
+      if (blockIt == regionIt->end()) {
+        regionIt++;
+        return;
+      }
+      blockIt++;
+      if (blockIt != regionIt->end()) {
+        LDBG() << "Incrementing block iterator, next op: "
+               << OpWithFlags(&*blockIt, OpPrintingFlags().skipRegions());
+      }
+    }
+    Region *region;
+    Region::iterator regionIt;
+    Block::iterator blockIt;
+    bool hasVisitedRegions = false;
+  };
+  SmallVector<Cursor> worklist;
+
+  LDBG() << "Starting walk-based pattern rewrite driver";
   ctx->executeAction<WalkAndApplyPatternsAction>(
       [&] {
         for (Region &region : op->getRegions()) {
-          region.walk([&](Operation *visitedOp) {
-            LLVM_DEBUG(llvm::dbgs() << "Visiting op: "; visitedOp->print(
-                llvm::dbgs(), OpPrintingFlags().skipRegions());
-                       llvm::dbgs() << "\n";);
+          assert(worklist.empty());
+          if (region.empty())
+            continue;
+
+          // Prime the worklist with the entry block of this region.
+          worklist.push_back({&region});
+          while (!worklist.empty()) {
+            Cursor &cursor = worklist.back();
+            if (cursor.regionIt == cursor.region->end()) {
+              // We're done with this region.
+              worklist.pop_back();
+              continue;
+            }
+            if (cursor.blockIt == cursor.regionIt->end()) {
+              // We're done with this block.
+              cursor.regionIt++;
+              if (cursor.regionIt != cursor.region->end())
+                cursor.blockIt = cursor.regionIt->begin();
+              continue;
+            }
+            Operation *op = &*cursor.blockIt;
+            if (!cursor.hasVisitedRegions) {
+              cursor.hasVisitedRegions = true;
+              for (Region &nestedRegion : llvm::reverse(op->getRegions())) {
+                if (nestedRegion.empty())
+                  continue;
+                worklist.push_back({&nestedRegion});
+              }
+            }
+            // If we're not at the back of the worklist, we're visiting a nested
+            // region first. We'll come back to this op later.
+            if (&cursor != &worklist.back())
+              continue;
+
+            // Premptively increment the cursor, in case the current op
+            // would be erased.
+            cursor.next();
+
+            LDBG() << "Visiting op: "
+                   << OpWithFlags(op, OpPrintingFlags().skipRegions());
 #if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
-            erasedListener.visitedOp = visitedOp;
+            erasedListener.visitedOp = op;
 #endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
-            if (succeeded(applicator.matchAndRewrite(visitedOp, rewriter))) {
-              LLVM_DEBUG(llvm::dbgs() << "\tOp matched and rewritten\n";);
-            }
-          });
+            if (succeeded(applicator.matchAndRewrite(op, rewriter)))
+              LDBG() << "\tOp matched and rewritten";
+          }
         }
       },
       {op});
diff --git a/mlir/test/IR/test-walk-pattern-rewrite-driver.mlir b/mlir/test/IR/test-walk-pattern-rewrite-driver.mlir
index 02f7e60671c9b..c75c478ec3734 100644
--- a/mlir/test/IR/test-walk-pattern-rewrite-driver.mlir
+++ b/mlir/test/IR/test-walk-pattern-rewrite-driver.mlir
@@ -40,12 +40,12 @@ func.func @move_before(%cond : i1) {
 }
 
 // Check that the driver handles rewriter.moveAfter. In this case, we expect
-// the moved op to be visited only once since walk uses `make_early_inc_range`.
+// the moved op to be visited twice.
 // CHECK-LABEL: func.func @move_after(
 // CHECK: scf.if
 // CHECK: }
 // CHECK: "test.move_after_parent_op"
-// CHECK: "test.any_attr_of_i32_str"() <{attr = 1 : i32}> : () -> ()
+// CHECK: "test.any_attr_of_i32_str"() <{attr = 2 : i32}> : () -> ()
 // CHECK: return
 func.func @move_after(%cond : i1) {
   scf.if %cond {

@joker-eph joker-eph changed the title [MLIR] Refactor the walkAndApplyPatterns driver to remove the recursion (NFC) [MLIR] Refactor the walkAndApplyPatterns driver to remove the recursion Aug 17, 2025
@joker-eph joker-eph requested a review from Copilot August 17, 2025 21:35
@joker-eph joker-eph force-pushed the users/joker-eph/walk_derecurse branch from f98c1aa to c074f74 Compare August 17, 2025 21:40
@joker-eph joker-eph requested review from Copilot and removed request for Copilot August 17, 2025 21:43
Copy link
Contributor

@Copilot Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull Request Overview

This PR refactors the walkAndApplyPatterns driver to replace recursive traversal with an iterative approach using a worklist. This change is in preparation for future modifications to stop traversing unreachable blocks and addresses a subtle difference in when the iterator is pre-incremented relative to processing nested regions.

Key changes:

  • Replaces the recursive region.walk() with an iterative post-order traversal using a custom RegionReachableOpIterator
  • Changes the timing of iterator pre-increment to occur after processing nested regions rather than before
  • Updates test expectations to reflect that moved operations are now visited twice instead of once

Reviewed Changes

Copilot reviewed 2 out of 2 changed files in this pull request and generated 3 comments.

File Description
mlir/lib/Transforms/Utils/WalkPatternRewriteDriver.cpp Implements new iterative traversal logic with custom iterator and worklist
mlir/test/IR/test-walk-pattern-rewrite-driver.mlir Updates test expectations for the new traversal behavior

@joker-eph joker-eph force-pushed the users/joker-eph/walk_derecurse branch from c074f74 to fa92472 Compare August 17, 2025 21:51
@joker-eph joker-eph requested a review from jpienaar August 17, 2025 21:58
Copy link
Member

@kuhar kuhar left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is not NFC because of a subtlety of the early_inc.
...

This sounds perfectly fine IMO

…on (NFC)

This is in preparation of a follow-up change to stop traversing unreachable blocks.
@joker-eph joker-eph force-pushed the users/joker-eph/walk_derecurse branch from fa92472 to a3d9296 Compare August 18, 2025 08:53
@joker-eph joker-eph enabled auto-merge (squash) August 18, 2025 08:53
@joker-eph joker-eph merged commit 16aa283 into main Aug 18, 2025
9 checks passed
@joker-eph joker-eph deleted the users/joker-eph/walk_derecurse branch August 18, 2025 09:07
joker-eph added a commit to joker-eph/llvm-project that referenced this pull request Aug 20, 2025
…on (llvm#154037)

This is in preparation of a follow-up change to stop traversing
unreachable blocks.

This is not NFC because of a subtlety of the early_inc. On a test case
like:

```
  scf.if %cond {
    "test.move_after_parent_op"() ({
      "test.any_attr_of_i32_str"() {attr = 0 : i32} : () -> ()
    }) : () -> ()
  }
```

We recursively traverse the nested regions, and process an op when the
region is done (post-order).
We need to pre-increment the iterator before processing an operation in
case it gets deleted. However
we can do this before or after processing the nested region. This
implementation does the latter.
tru pushed a commit that referenced this pull request Aug 26, 2025
…on (#154037)

This is in preparation of a follow-up change to stop traversing
unreachable blocks.

This is not NFC because of a subtlety of the early_inc. On a test case
like:

```
  scf.if %cond {
    "test.move_after_parent_op"() ({
      "test.any_attr_of_i32_str"() {attr = 0 : i32} : () -> ()
    }) : () -> ()
  }
```

We recursively traverse the nested regions, and process an op when the
region is done (post-order).
We need to pre-increment the iterator before processing an operation in
case it gets deleted. However
we can do this before or after processing the nested region. This
implementation does the latter.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

mlir:core MLIR Core Infrastructure mlir

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants