Skip to content

Conversation

@matthias-springer
Copy link
Member

@matthias-springer matthias-springer commented Jun 20, 2025

Add missing listener notifications when erasing nested blocks/operations.

This commit also moves some of the functionality from ConversionPatternRewriter to ConversionPatternRewriterImpl. This is in preparation of the One-Shot Dialect Conversion refactoring: The implementations in ConversionPatternRewriter should be as simple as possible, so that a switch between "rollback allowed" and "rollback not allowed" can be inserted at that level. (In the latter case, ConversionPatternRewriterImpl can be bypassed to some degree, and PatternRewriter::eraseBlock etc. can be used.)

Depends on #145018.

@llvmbot llvmbot added mlir:core MLIR Core Infrastructure mlir labels Jun 20, 2025
@llvmbot
Copy link
Member

llvmbot commented Jun 20, 2025

@llvm/pr-subscribers-mlir

Author: Matthias Springer (matthias-springer)

Changes

Add missing listener notifications when erasing nested blocks/operations.

This commit also moves some of the functionality from ConversionPatternRewriter to ConversionPatternRewriterImpl. This is in preparation of the One-Shot Dialect Conversion refactoring: The implementations in ConversionPatternRewriter should be as simple as possible, so that a switch between "rollback allowed" and "rollback not allowed" can be inserted at that level. (In the latter case, ConversionPatternRewriterImpl can be bypassed to some degree, and PatternRewriter::eraseBlock etc. can be used.)


Full diff: https://github.com/llvm/llvm-project/pull/145030.diff

2 Files Affected:

  • (modified) mlir/lib/Transforms/Utils/DialectConversion.cpp (+43-19)
  • (modified) mlir/test/Transforms/test-legalizer.mlir (+16-2)
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index ff48647f43305..7419d79cd8856 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -274,6 +274,26 @@ struct RewriterState {
 // IR rewrites
 //===----------------------------------------------------------------------===//
 
+static void notifyIRErased(RewriterBase::Listener *listener, Operation &op);
+
+/// Notify the listener that the given block and its contents are being erased.
+static void notifyIRErased(RewriterBase::Listener *listener, Block &b) {
+  for (Operation &op : b)
+    notifyIRErased(listener, op);
+  listener->notifyBlockErased(&b);
+}
+
+/// Notify the listener that the given operation and its contents are being
+/// erased.
+static void notifyIRErased(RewriterBase::Listener *listener, Operation &op) {
+  for (Region &r : op.getRegions()) {
+    for (Block &b : r) {
+      notifyIRErased(listener, b);
+    }
+  }
+  listener->notifyOperationErased(&op);
+}
+
 /// An IR rewrite that can be committed (upon success) or rolled back (upon
 /// failure).
 ///
@@ -422,17 +442,20 @@ class EraseBlockRewrite : public BlockRewrite {
   }
 
   void commit(RewriterBase &rewriter) override {
-    // Erase the block.
     assert(block && "expected block");
-    assert(block->empty() && "expected empty block");
 
-    // Notify the listener that the block is about to be erased.
+    // Notify the listener that the block and its contents are being erased.
     if (auto *listener =
             dyn_cast_or_null<RewriterBase::Listener>(rewriter.getListener()))
-      listener->notifyBlockErased(block);
+      notifyIRErased(listener, *block);
   }
 
   void cleanup(RewriterBase &rewriter) override {
+    // Erase the contents of the block.
+    for (auto &op : llvm::make_early_inc_range(llvm::reverse(*block)))
+      rewriter.eraseOp(&op);
+    assert(block->empty() && "expected empty block");
+
     // Erase the block.
     block->dropAllDefinedValueUses();
     delete block;
@@ -1147,12 +1170,9 @@ void ReplaceOperationRewrite::commit(RewriterBase &rewriter) {
   if (getConfig().unlegalizedOps)
     getConfig().unlegalizedOps->erase(op);
 
-  // Notify the listener that the operation (and its nested operations) was
-  // erased.
-  if (listener) {
-    op->walk<WalkOrder::PostOrder>(
-        [&](Operation *op) { listener->notifyOperationErased(op); });
-  }
+  // Notify the listener that the operation and its contents are being erased.
+  if (listener)
+    notifyIRErased(listener, *op);
 
   // Do not erase the operation yet. It may still be referenced in `mapping`.
   // Just unlink it for now and erase it during cleanup.
@@ -1605,6 +1625,8 @@ void ConversionPatternRewriterImpl::replaceOp(
 }
 
 void ConversionPatternRewriterImpl::eraseBlock(Block *block) {
+  assert(!wasOpReplaced(block->getParentOp()) &&
+         "attempting to erase a block within a replaced/erased op");
   appendRewrite<EraseBlockRewrite>(block);
 
   // Unlink the block from its parent region. The block is kept in the rewrite
@@ -1612,12 +1634,16 @@ void ConversionPatternRewriterImpl::eraseBlock(Block *block) {
   // allows us to keep the operations in the block live and undo the removal by
   // re-inserting the block.
   block->getParent()->getBlocks().remove(block);
+
+  // Mark all nested ops as erased.
+  block->walk([&](Operation *op) { replacedOps.insert(op); });
 }
 
 void ConversionPatternRewriterImpl::notifyBlockInserted(
     Block *block, Region *previous, Region::iterator previousIt) {
-  assert(!wasOpReplaced(block->getParentOp()) &&
-         "attempting to insert into a region within a replaced/erased op");
+  assert(
+      (!config.allowPatternRollback || !wasOpReplaced(block->getParentOp())) &&
+      "attempting to insert into a region within a replaced/erased op");
   LLVM_DEBUG(
       {
         Operation *parent = block->getParentOp();
@@ -1630,6 +1656,11 @@ void ConversionPatternRewriterImpl::notifyBlockInserted(
         }
       });
 
+  if (!config.allowPatternRollback) {
+    // Pattern rollback is not allowed. No extra bookkeeping is needed.
+    return;
+  }
+
   if (!previous) {
     // This is a newly created block.
     appendRewrite<CreateBlockRewrite>(block);
@@ -1709,13 +1740,6 @@ void ConversionPatternRewriter::eraseOp(Operation *op) {
 }
 
 void ConversionPatternRewriter::eraseBlock(Block *block) {
-  assert(!impl->wasOpReplaced(block->getParentOp()) &&
-         "attempting to erase a block within a replaced/erased op");
-
-  // Mark all ops for erasure.
-  for (Operation &op : *block)
-    eraseOp(&op);
-
   impl->eraseBlock(block);
 }
 
diff --git a/mlir/test/Transforms/test-legalizer.mlir b/mlir/test/Transforms/test-legalizer.mlir
index 34948ae685f0a..204c8c1456826 100644
--- a/mlir/test/Transforms/test-legalizer.mlir
+++ b/mlir/test/Transforms/test-legalizer.mlir
@@ -461,12 +461,26 @@ func.func @convert_detached_signature() {
 
 // -----
 
+// CHECK: notifyOperationReplaced: test.erase_op
+// CHECK: notifyOperationErased: test.dummy_op_lvl_2
+// CHECK: notifyBlockErased
+// CHECK: notifyOperationErased: test.dummy_op_lvl_1
+// CHECK: notifyBlockErased
+// CHECK: notifyOperationErased: test.erase_op
+// CHECK: notifyOperationInserted: test.valid, was unlinked
+// CHECK: notifyOperationReplaced: test.drop_operands_and_replace_with_valid
+// CHECK: notifyOperationErased: test.drop_operands_and_replace_with_valid
+
 // CHECK-LABEL: func @circular_mapping()
 //  CHECK-NEXT:   "test.valid"() : () -> ()
 func.func @circular_mapping() {
   // Regression test that used to crash due to circular
-  // unrealized_conversion_cast ops.
-  %0 = "test.erase_op"() : () -> (i64)
+  // unrealized_conversion_cast ops. 
+  %0 = "test.erase_op"() ({
+    "test.dummy_op_lvl_1"() ({
+      "test.dummy_op_lvl_2"() : () -> ()
+    }) : () -> ()
+  }): () -> (i64)
   "test.drop_operands_and_replace_with_valid"(%0) : (i64) -> ()
 }
 

@llvmbot
Copy link
Member

llvmbot commented Jun 20, 2025

@llvm/pr-subscribers-mlir-core

Author: Matthias Springer (matthias-springer)

Changes

Add missing listener notifications when erasing nested blocks/operations.

This commit also moves some of the functionality from ConversionPatternRewriter to ConversionPatternRewriterImpl. This is in preparation of the One-Shot Dialect Conversion refactoring: The implementations in ConversionPatternRewriter should be as simple as possible, so that a switch between "rollback allowed" and "rollback not allowed" can be inserted at that level. (In the latter case, ConversionPatternRewriterImpl can be bypassed to some degree, and PatternRewriter::eraseBlock etc. can be used.)


Full diff: https://github.com/llvm/llvm-project/pull/145030.diff

2 Files Affected:

  • (modified) mlir/lib/Transforms/Utils/DialectConversion.cpp (+43-19)
  • (modified) mlir/test/Transforms/test-legalizer.mlir (+16-2)
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index ff48647f43305..7419d79cd8856 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -274,6 +274,26 @@ struct RewriterState {
 // IR rewrites
 //===----------------------------------------------------------------------===//
 
+static void notifyIRErased(RewriterBase::Listener *listener, Operation &op);
+
+/// Notify the listener that the given block and its contents are being erased.
+static void notifyIRErased(RewriterBase::Listener *listener, Block &b) {
+  for (Operation &op : b)
+    notifyIRErased(listener, op);
+  listener->notifyBlockErased(&b);
+}
+
+/// Notify the listener that the given operation and its contents are being
+/// erased.
+static void notifyIRErased(RewriterBase::Listener *listener, Operation &op) {
+  for (Region &r : op.getRegions()) {
+    for (Block &b : r) {
+      notifyIRErased(listener, b);
+    }
+  }
+  listener->notifyOperationErased(&op);
+}
+
 /// An IR rewrite that can be committed (upon success) or rolled back (upon
 /// failure).
 ///
@@ -422,17 +442,20 @@ class EraseBlockRewrite : public BlockRewrite {
   }
 
   void commit(RewriterBase &rewriter) override {
-    // Erase the block.
     assert(block && "expected block");
-    assert(block->empty() && "expected empty block");
 
-    // Notify the listener that the block is about to be erased.
+    // Notify the listener that the block and its contents are being erased.
     if (auto *listener =
             dyn_cast_or_null<RewriterBase::Listener>(rewriter.getListener()))
-      listener->notifyBlockErased(block);
+      notifyIRErased(listener, *block);
   }
 
   void cleanup(RewriterBase &rewriter) override {
+    // Erase the contents of the block.
+    for (auto &op : llvm::make_early_inc_range(llvm::reverse(*block)))
+      rewriter.eraseOp(&op);
+    assert(block->empty() && "expected empty block");
+
     // Erase the block.
     block->dropAllDefinedValueUses();
     delete block;
@@ -1147,12 +1170,9 @@ void ReplaceOperationRewrite::commit(RewriterBase &rewriter) {
   if (getConfig().unlegalizedOps)
     getConfig().unlegalizedOps->erase(op);
 
-  // Notify the listener that the operation (and its nested operations) was
-  // erased.
-  if (listener) {
-    op->walk<WalkOrder::PostOrder>(
-        [&](Operation *op) { listener->notifyOperationErased(op); });
-  }
+  // Notify the listener that the operation and its contents are being erased.
+  if (listener)
+    notifyIRErased(listener, *op);
 
   // Do not erase the operation yet. It may still be referenced in `mapping`.
   // Just unlink it for now and erase it during cleanup.
@@ -1605,6 +1625,8 @@ void ConversionPatternRewriterImpl::replaceOp(
 }
 
 void ConversionPatternRewriterImpl::eraseBlock(Block *block) {
+  assert(!wasOpReplaced(block->getParentOp()) &&
+         "attempting to erase a block within a replaced/erased op");
   appendRewrite<EraseBlockRewrite>(block);
 
   // Unlink the block from its parent region. The block is kept in the rewrite
@@ -1612,12 +1634,16 @@ void ConversionPatternRewriterImpl::eraseBlock(Block *block) {
   // allows us to keep the operations in the block live and undo the removal by
   // re-inserting the block.
   block->getParent()->getBlocks().remove(block);
+
+  // Mark all nested ops as erased.
+  block->walk([&](Operation *op) { replacedOps.insert(op); });
 }
 
 void ConversionPatternRewriterImpl::notifyBlockInserted(
     Block *block, Region *previous, Region::iterator previousIt) {
-  assert(!wasOpReplaced(block->getParentOp()) &&
-         "attempting to insert into a region within a replaced/erased op");
+  assert(
+      (!config.allowPatternRollback || !wasOpReplaced(block->getParentOp())) &&
+      "attempting to insert into a region within a replaced/erased op");
   LLVM_DEBUG(
       {
         Operation *parent = block->getParentOp();
@@ -1630,6 +1656,11 @@ void ConversionPatternRewriterImpl::notifyBlockInserted(
         }
       });
 
+  if (!config.allowPatternRollback) {
+    // Pattern rollback is not allowed. No extra bookkeeping is needed.
+    return;
+  }
+
   if (!previous) {
     // This is a newly created block.
     appendRewrite<CreateBlockRewrite>(block);
@@ -1709,13 +1740,6 @@ void ConversionPatternRewriter::eraseOp(Operation *op) {
 }
 
 void ConversionPatternRewriter::eraseBlock(Block *block) {
-  assert(!impl->wasOpReplaced(block->getParentOp()) &&
-         "attempting to erase a block within a replaced/erased op");
-
-  // Mark all ops for erasure.
-  for (Operation &op : *block)
-    eraseOp(&op);
-
   impl->eraseBlock(block);
 }
 
diff --git a/mlir/test/Transforms/test-legalizer.mlir b/mlir/test/Transforms/test-legalizer.mlir
index 34948ae685f0a..204c8c1456826 100644
--- a/mlir/test/Transforms/test-legalizer.mlir
+++ b/mlir/test/Transforms/test-legalizer.mlir
@@ -461,12 +461,26 @@ func.func @convert_detached_signature() {
 
 // -----
 
+// CHECK: notifyOperationReplaced: test.erase_op
+// CHECK: notifyOperationErased: test.dummy_op_lvl_2
+// CHECK: notifyBlockErased
+// CHECK: notifyOperationErased: test.dummy_op_lvl_1
+// CHECK: notifyBlockErased
+// CHECK: notifyOperationErased: test.erase_op
+// CHECK: notifyOperationInserted: test.valid, was unlinked
+// CHECK: notifyOperationReplaced: test.drop_operands_and_replace_with_valid
+// CHECK: notifyOperationErased: test.drop_operands_and_replace_with_valid
+
 // CHECK-LABEL: func @circular_mapping()
 //  CHECK-NEXT:   "test.valid"() : () -> ()
 func.func @circular_mapping() {
   // Regression test that used to crash due to circular
-  // unrealized_conversion_cast ops.
-  %0 = "test.erase_op"() : () -> (i64)
+  // unrealized_conversion_cast ops. 
+  %0 = "test.erase_op"() ({
+    "test.dummy_op_lvl_1"() ({
+      "test.dummy_op_lvl_2"() : () -> ()
+    }) : () -> ()
+  }): () -> (i64)
   "test.drop_operands_and_replace_with_valid"(%0) : (i64) -> ()
 }
 

@matthias-springer matthias-springer force-pushed the users/matthias-springer/erase_notifications branch 2 times, most recently from d65161f to edb49ec Compare June 20, 2025 12:30
Copy link
Contributor

@j2kun j2kun left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM!

Base automatically changed from users/matthias-springer/rename_impl to main June 21, 2025 07:43
@matthias-springer matthias-springer force-pushed the users/matthias-springer/erase_notifications branch from edb49ec to 2fe4d14 Compare June 21, 2025 07:44
@matthias-springer matthias-springer merged commit 0921bfd into main Jun 21, 2025
7 checks passed
@matthias-springer matthias-springer deleted the users/matthias-springer/erase_notifications branch June 21, 2025 08:44
matthias-springer added a commit that referenced this pull request Jun 23, 2025
…ation (#145155)

Since #145030, `ConversionPatternRewriter::eraseBlock` no longer calls
`ConversionPatternRewriter::eraseOp`. This now happens in the rewriter
impl (during the cleanup phase). Therefore, a safety check in
`replaceOp` can now be simplified.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

mlir:core MLIR Core Infrastructure mlir

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants