Skip to content

Conversation

@matthias-springer
Copy link
Member

unresolvedMaterializations is a mapping from UnrealizedConversionCastOp to UnresolvedMaterializationRewrite. This mapping is needed to find the correct type converter for an unresolved materialization.

With this commit, unresolvedMaterializations is updated immediately when an op is being erased. This also cleans up the code base a bit: SingleEraseRewriter is now used only during the "cleanup" phase and no longer needed as a field of ConversionRewriterImpl.

This commit is in preparation of the One-Shot Dialect Conversion refactoring: allowPatternRollback = false will in the future trigger immediate materialization of all IR changes.

@llvmbot llvmbot added mlir:core MLIR Core Infrastructure mlir labels Jun 15, 2025
@llvmbot
Copy link
Member

llvmbot commented Jun 15, 2025

@llvm/pr-subscribers-mlir-core

Author: Matthias Springer (matthias-springer)

Changes

unresolvedMaterializations is a mapping from UnrealizedConversionCastOp to UnresolvedMaterializationRewrite. This mapping is needed to find the correct type converter for an unresolved materialization.

With this commit, unresolvedMaterializations is updated immediately when an op is being erased. This also cleans up the code base a bit: SingleEraseRewriter is now used only during the "cleanup" phase and no longer needed as a field of ConversionRewriterImpl.

This commit is in preparation of the One-Shot Dialect Conversion refactoring: allowPatternRollback = false will in the future trigger immediate materialization of all IR changes.


Full diff: https://github.com/llvm/llvm-project/pull/144254.diff

1 Files Affected:

  • (modified) mlir/lib/Transforms/Utils/DialectConversion.cpp (+20-13)
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index 7de26d7cfa84d..b5345fb1a2dcb 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -848,7 +848,7 @@ namespace detail {
 struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
   explicit ConversionPatternRewriterImpl(MLIRContext *ctx,
                                          const ConversionConfig &config)
-      : context(ctx), eraseRewriter(ctx), config(config) {}
+      : context(ctx), config(config) {}
 
   //===--------------------------------------------------------------------===//
   // State Management
@@ -981,8 +981,11 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
   /// no new IR is created between calls to `eraseOp`/`eraseBlock`.
   struct SingleEraseRewriter : public RewriterBase, RewriterBase::Listener {
   public:
-    SingleEraseRewriter(MLIRContext *context)
-        : RewriterBase(context, /*listener=*/this) {}
+    SingleEraseRewriter(
+        MLIRContext *context,
+        llvm::function_ref<void(Operation *)> opErasedCallback = nullptr)
+        : RewriterBase(context, /*listener=*/this),
+          opErasedCallback(opErasedCallback) {}
 
     /// Erase the given op (unless it was already erased).
     void eraseOp(Operation *op) override {
@@ -1003,13 +1006,20 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
 
     bool wasErased(void *ptr) const { return erased.contains(ptr); }
 
-    void notifyOperationErased(Operation *op) override { erased.insert(op); }
+    void notifyOperationErased(Operation *op) override {
+      erased.insert(op);
+      if (opErasedCallback)
+        opErasedCallback(op);
+    }
 
     void notifyBlockErased(Block *block) override { erased.insert(block); }
 
   private:
     /// Pointers to all erased operations and blocks.
     DenseSet<void *> erased;
+
+    /// A callback that is invoked when an operation is erased.
+    llvm::function_ref<void(Operation *)> opErasedCallback;
   };
 
   //===--------------------------------------------------------------------===//
@@ -1019,11 +1029,6 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
   /// MLIR context.
   MLIRContext *context;
 
-  /// A rewriter that keeps track of ops/block that were already erased and
-  /// skips duplicate op/block erasures. This rewriter is used during the
-  /// "cleanup" phase.
-  SingleEraseRewriter eraseRewriter;
-
   // Mapping between replaced values that differ in type. This happens when
   // replacing a value with one of a different type.
   ConversionValueMapping mapping;
@@ -1195,6 +1200,11 @@ void ConversionPatternRewriterImpl::applyRewrites() {
     rewrites[i]->commit(rewriter);
 
   // Clean up all rewrites.
+  SingleEraseRewriter eraseRewriter(
+      context, /*opErasedCallback=*/[&](Operation *op) {
+        if (auto castOp = dyn_cast<UnrealizedConversionCastOp>(op))
+          unresolvedMaterializations.erase(castOp);
+      });
   for (auto &rewrite : rewrites)
     rewrite->cleanup(eraseRewriter);
 }
@@ -2714,11 +2724,8 @@ LogicalResult OperationConverter::convertOperations(ArrayRef<Operation *> ops) {
   SmallVector<UnrealizedConversionCastOp> allCastOps;
   const DenseMap<UnrealizedConversionCastOp, UnresolvedMaterializationRewrite *>
       &materializations = rewriterImpl.unresolvedMaterializations;
-  for (auto it : materializations) {
-    if (rewriterImpl.eraseRewriter.wasErased(it.first))
-      continue;
+  for (auto it : materializations)
     allCastOps.push_back(it.first);
-  }
 
   // Reconcile all UnrealizedConversionCastOps that were inserted by the
   // dialect conversion frameworks. (Not the one that were inserted by

@llvmbot
Copy link
Member

llvmbot commented Jun 15, 2025

@llvm/pr-subscribers-mlir

Author: Matthias Springer (matthias-springer)

Changes

unresolvedMaterializations is a mapping from UnrealizedConversionCastOp to UnresolvedMaterializationRewrite. This mapping is needed to find the correct type converter for an unresolved materialization.

With this commit, unresolvedMaterializations is updated immediately when an op is being erased. This also cleans up the code base a bit: SingleEraseRewriter is now used only during the "cleanup" phase and no longer needed as a field of ConversionRewriterImpl.

This commit is in preparation of the One-Shot Dialect Conversion refactoring: allowPatternRollback = false will in the future trigger immediate materialization of all IR changes.


Full diff: https://github.com/llvm/llvm-project/pull/144254.diff

1 Files Affected:

  • (modified) mlir/lib/Transforms/Utils/DialectConversion.cpp (+20-13)
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index 7de26d7cfa84d..b5345fb1a2dcb 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -848,7 +848,7 @@ namespace detail {
 struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
   explicit ConversionPatternRewriterImpl(MLIRContext *ctx,
                                          const ConversionConfig &config)
-      : context(ctx), eraseRewriter(ctx), config(config) {}
+      : context(ctx), config(config) {}
 
   //===--------------------------------------------------------------------===//
   // State Management
@@ -981,8 +981,11 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
   /// no new IR is created between calls to `eraseOp`/`eraseBlock`.
   struct SingleEraseRewriter : public RewriterBase, RewriterBase::Listener {
   public:
-    SingleEraseRewriter(MLIRContext *context)
-        : RewriterBase(context, /*listener=*/this) {}
+    SingleEraseRewriter(
+        MLIRContext *context,
+        llvm::function_ref<void(Operation *)> opErasedCallback = nullptr)
+        : RewriterBase(context, /*listener=*/this),
+          opErasedCallback(opErasedCallback) {}
 
     /// Erase the given op (unless it was already erased).
     void eraseOp(Operation *op) override {
@@ -1003,13 +1006,20 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
 
     bool wasErased(void *ptr) const { return erased.contains(ptr); }
 
-    void notifyOperationErased(Operation *op) override { erased.insert(op); }
+    void notifyOperationErased(Operation *op) override {
+      erased.insert(op);
+      if (opErasedCallback)
+        opErasedCallback(op);
+    }
 
     void notifyBlockErased(Block *block) override { erased.insert(block); }
 
   private:
     /// Pointers to all erased operations and blocks.
     DenseSet<void *> erased;
+
+    /// A callback that is invoked when an operation is erased.
+    llvm::function_ref<void(Operation *)> opErasedCallback;
   };
 
   //===--------------------------------------------------------------------===//
@@ -1019,11 +1029,6 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
   /// MLIR context.
   MLIRContext *context;
 
-  /// A rewriter that keeps track of ops/block that were already erased and
-  /// skips duplicate op/block erasures. This rewriter is used during the
-  /// "cleanup" phase.
-  SingleEraseRewriter eraseRewriter;
-
   // Mapping between replaced values that differ in type. This happens when
   // replacing a value with one of a different type.
   ConversionValueMapping mapping;
@@ -1195,6 +1200,11 @@ void ConversionPatternRewriterImpl::applyRewrites() {
     rewrites[i]->commit(rewriter);
 
   // Clean up all rewrites.
+  SingleEraseRewriter eraseRewriter(
+      context, /*opErasedCallback=*/[&](Operation *op) {
+        if (auto castOp = dyn_cast<UnrealizedConversionCastOp>(op))
+          unresolvedMaterializations.erase(castOp);
+      });
   for (auto &rewrite : rewrites)
     rewrite->cleanup(eraseRewriter);
 }
@@ -2714,11 +2724,8 @@ LogicalResult OperationConverter::convertOperations(ArrayRef<Operation *> ops) {
   SmallVector<UnrealizedConversionCastOp> allCastOps;
   const DenseMap<UnrealizedConversionCastOp, UnresolvedMaterializationRewrite *>
       &materializations = rewriterImpl.unresolvedMaterializations;
-  for (auto it : materializations) {
-    if (rewriterImpl.eraseRewriter.wasErased(it.first))
-      continue;
+  for (auto it : materializations)
     allCastOps.push_back(it.first);
-  }
 
   // Reconcile all UnrealizedConversionCastOps that were inserted by the
   // dialect conversion frameworks. (Not the one that were inserted by

@matthias-springer matthias-springer force-pushed the users/matthias-springer/dialect_conv_erase branch from b780a86 to 7596373 Compare June 18, 2025 09:10
Copy link
Member

@zero9178 zero9178 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM!

@matthias-springer matthias-springer merged commit 66580f7 into main Jun 18, 2025
7 checks passed
@matthias-springer matthias-springer deleted the users/matthias-springer/dialect_conv_erase branch June 18, 2025 12:42
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

mlir:core MLIR Core Infrastructure mlir

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants