Skip to content

[mlir][Transforms] Dialect Conversion Driver without Rollback #151865

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Aug 13, 2025

Conversation

matthias-springer
Copy link
Member

@matthias-springer matthias-springer commented Aug 3, 2025

This commit improves the allowPatternRollback flag handling in the dialect conversion driver. Previously, this flag was used to merely detect cases that are incompatible with the new One-Shot Dialect Conversion driver. This commit implements the driver itself: when the flag is set to "false", all IR changes are materialized immediately, bypassing the IRRewrite and ConversionValueMapping infrastructure.

A few selected test cases now run with both the old and the new driver.

RFC: https://discourse.llvm.org/t/rfc-a-new-one-shot-dialect-conversion-driver/79083

Copy link

github-actions bot commented Aug 3, 2025

✅ With the latest revision this PR passed the C/C++ code formatter.

@matthias-springer matthias-springer force-pushed the users/matthias-springer/prototype_one_shot branch 2 times, most recently from 4932ff9 to e789428 Compare August 3, 2025 18:10
@matthias-springer matthias-springer force-pushed the users/matthias-springer/dialect_conv_no_tc branch 2 times, most recently from 9818ca9 to 45f831d Compare August 6, 2025 11:44
@matthias-springer matthias-springer force-pushed the users/matthias-springer/prototype_one_shot branch from e789428 to a94080b Compare August 6, 2025 11:59
Base automatically changed from users/matthias-springer/dialect_conv_no_tc to main August 7, 2025 06:41
@matthias-springer matthias-springer force-pushed the users/matthias-springer/prototype_one_shot branch 3 times, most recently from fb442b3 to 371998e Compare August 9, 2025 09:34
Copy link
Contributor

@Copilot Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull Request Overview

This commit implements an experimental One-Shot Dialect Conversion driver that bypasses the rollback infrastructure when allowPatternRollback is set to false. This improves performance by materializing IR changes immediately instead of maintaining rollback state, though it requires patterns to avoid API misuse since rollback is no longer possible.

  • Implements the new "no rollback" conversion driver that materializes changes immediately
  • Adds extensive test coverage for the new driver across integration and conversion tests
  • Updates the ConvertToLLVM pass to support the new allow-pattern-rollback flag

Reviewed Changes

Copilot reviewed 30 out of 30 changed files in this pull request and generated 3 comments.

Show a summary per file
File Description
mlir/lib/Transforms/Utils/DialectConversion.cpp Core implementation of the One-Shot driver with immediate IR materialization
mlir/test/lib/Dialect/Test/TestPatterns.cpp Adds test flag and fixes pattern insertion point issues
mlir/lib/Conversion/ConvertToLLVM/ConvertToLLVMPass.cpp Adds support for the new flag in the convert-to-llvm pass
mlir/include/mlir/Transforms/DialectConversion.h Updates documentation for the experimental flag
Various test files Adds test coverage with both rollback and no-rollback modes
Comments suppressed due to low confidence (1)

mlir/lib/Transforms/Utils/DialectConversion.cpp:1778

  • This assertion in getReplacementValues could interfere with performance profiling data collection. The assertion affects the control flow in a function that's likely called frequently during conversion, potentially impacting performance measurements.
      // thin air".

@llvmbot
Copy link
Member

llvmbot commented Aug 9, 2025

@llvm/pr-subscribers-mlir-memref
@llvm/pr-subscribers-mlir-tensor
@llvm/pr-subscribers-mlir-core
@llvm/pr-subscribers-mlir-openmp

@llvm/pr-subscribers-mlir

Author: Matthias Springer (matthias-springer)

Changes

This commit improves the allowPatternRollback flag handling in the dialect conversion driver. Previously, this flag was used to merely detect cases that are incompatible with the new One-Shot Dialect Conversion driver. This commit implements the driver itself: when the flag is set to "false", all IR changes are materialized immediately, bypassing the IRRewrite and ConversionValueMapping infrastructure.

A few selected test cases now run with both the old and the new driver.

RFC: https://discourse.llvm.org/t/rfc-a-new-one-shot-dialect-conversion-driver/79083


Patch is 62.06 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/151865.diff

30 Files Affected:

  • (modified) mlir/include/mlir/Conversion/Passes.td (+2)
  • (modified) mlir/include/mlir/Transforms/DialectConversion.h (+9-9)
  • (modified) mlir/lib/Conversion/ConvertToLLVM/ConvertToLLVMPass.cpp (+19-7)
  • (modified) mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp (+1-1)
  • (modified) mlir/lib/Transforms/Utils/DialectConversion.cpp (+305-59)
  • (modified) mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir (+4-4)
  • (modified) mlir/test/Conversion/ComplexToLLVM/convert-to-llvm.mlir (+4-3)
  • (modified) mlir/test/Conversion/ControlFlowToLLVM/assert.mlir (+1)
  • (modified) mlir/test/Conversion/FuncToLLVM/func-to-llvm.mlir (+1)
  • (modified) mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm-target-attr.mlir (+1)
  • (modified) mlir/test/Conversion/IndexToLLVM/index-to-llvm.mlir (+1)
  • (modified) mlir/test/Conversion/MathToLLVM/math-to-llvm.mlir (+1)
  • (modified) mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir (+1)
  • (modified) mlir/test/Conversion/OpenMPToLLVM/convert-to-llvmir.mlir (+1)
  • (modified) mlir/test/Conversion/UBToLLVM/ub-to-llvm.mlir (+1)
  • (modified) mlir/test/Conversion/VectorToLLVM/vector-to-llvm-interface.mlir (+2-2)
  • (modified) mlir/test/Integration/Dialect/MemRef/assume-alignment-runtime-verification.mlir (+9)
  • (modified) mlir/test/Integration/Dialect/MemRef/atomic-rmw-runtime-verification.mlir (+8)
  • (modified) mlir/test/Integration/Dialect/MemRef/cast-runtime-verification.mlir (+10-1)
  • (modified) mlir/test/Integration/Dialect/MemRef/copy-runtime-verification.mlir (+9)
  • (modified) mlir/test/Integration/Dialect/MemRef/dim-runtime-verification.mlir (+9)
  • (modified) mlir/test/Integration/Dialect/MemRef/load-runtime-verification.mlir (+10-2)
  • (modified) mlir/test/Integration/Dialect/MemRef/store-runtime-verification.mlir (+8)
  • (modified) mlir/test/Integration/Dialect/MemRef/subview-runtime-verification.mlir (+11-1)
  • (modified) mlir/test/Integration/Dialect/Tensor/cast-runtime-verification.mlir (+11)
  • (modified) mlir/test/Integration/Dialect/Tensor/dim-runtime-verification.mlir (+13-3)
  • (modified) mlir/test/Integration/Dialect/Tensor/extract-runtime-verification.mlir (+11)
  • (modified) mlir/test/Integration/Dialect/Tensor/extract_slice-runtime-verification.mlir (+11)
  • (modified) mlir/test/Transforms/test-legalizer.mlir (+18-11)
  • (modified) mlir/test/lib/Dialect/Test/TestPatterns.cpp (+9-1)
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index 6e1baaf23fcf7..e6a80435816a3 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -52,6 +52,8 @@ def ConvertToLLVMPass : Pass<"convert-to-llvm"> {
                "Test conversion patterns of only the specified dialects">,
     Option<"useDynamic", "dynamic", "bool", "false",
            "Use op conversion attributes to configure the conversion">,
+    Option<"allowPatternRollback", "allow-pattern-rollback", "bool", "true",
+           "Experimental performance flag to disallow pattern rollback">
   ];
 }
 
diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h
index 4e651a0489899..00903006bb560 100644
--- a/mlir/include/mlir/Transforms/DialectConversion.h
+++ b/mlir/include/mlir/Transforms/DialectConversion.h
@@ -1231,16 +1231,16 @@ struct ConversionConfig {
   /// 2. Pattern produces IR (in-place modification or new IR) that is illegal
   ///    and cannot be legalized by subsequent foldings / pattern applications.
   ///
-  /// If set to "false", the conversion driver will produce an LLVM fatal error
-  /// instead of rolling back IR modifications. Moreover, in case of a failed
-  /// conversion, the original IR is not restored. The resulting IR may be a
-  /// mix of original and rewritten IR. (Same as a failed greedy pattern
-  /// rewrite.)
+  /// Experimental: If set to "false", the conversion driver will produce an
+  /// LLVM fatal error instead of rolling back IR modifications. Moreover, in
+  /// case of a failed conversion, the original IR is not restored. The
+  /// resulting IR may be a mix of original and rewritten IR. (Same as a failed
+  /// greedy pattern rewrite.) Use MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
+  /// with ASAN to detect invalid pattern API usage.
   ///
-  /// Note: This flag was added in preparation of the One-Shot Dialect
-  /// Conversion refactoring, which will remove the ability to roll back IR
-  /// modifications from the conversion driver. Use this flag to ensure that
-  /// your patterns do not trigger any IR rollbacks. For details, see
+  /// When pattern rollback is disabled, the conversion driver has to maintain
+  /// less internal state. This is more efficient, but not supported by all
+  /// lowering patterns. For details, see
   /// https://discourse.llvm.org/t/rfc-a-new-one-shot-dialect-conversion-driver/79083.
   bool allowPatternRollback = true;
 };
diff --git a/mlir/lib/Conversion/ConvertToLLVM/ConvertToLLVMPass.cpp b/mlir/lib/Conversion/ConvertToLLVM/ConvertToLLVMPass.cpp
index ed5d6d4a7fe40..cdb715064b0f7 100644
--- a/mlir/lib/Conversion/ConvertToLLVM/ConvertToLLVMPass.cpp
+++ b/mlir/lib/Conversion/ConvertToLLVM/ConvertToLLVMPass.cpp
@@ -31,7 +31,8 @@ namespace {
 class ConvertToLLVMPassInterface {
 public:
   ConvertToLLVMPassInterface(MLIRContext *context,
-                             ArrayRef<std::string> filterDialects);
+                             ArrayRef<std::string> filterDialects,
+                             bool allowPatternRollback = true);
   virtual ~ConvertToLLVMPassInterface() = default;
 
   /// Get the dependent dialects used by `convert-to-llvm`.
@@ -60,6 +61,9 @@ class ConvertToLLVMPassInterface {
   MLIRContext *context;
   /// List of dialects names to use as filters.
   ArrayRef<std::string> filterDialects;
+  /// An experimental flag to disallow pattern rollback. This is more efficient
+  /// but not supported by all lowering patterns.
+  bool allowPatternRollback;
 };
 
 /// This DialectExtension can be attached to the context, which will invoke the
@@ -128,7 +132,9 @@ struct StaticConvertToLLVM : public ConvertToLLVMPassInterface {
 
   /// Apply the conversion driver.
   LogicalResult transform(Operation *op, AnalysisManager manager) const final {
-    if (failed(applyPartialConversion(op, *target, *patterns)))
+    ConversionConfig config;
+    config.allowPatternRollback = allowPatternRollback;
+    if (failed(applyPartialConversion(op, *target, *patterns, config)))
       return failure();
     return success();
   }
@@ -179,7 +185,9 @@ struct DynamicConvertToLLVM : public ConvertToLLVMPassInterface {
                                               patterns);
 
     // Apply the conversion.
-    if (failed(applyPartialConversion(op, target, std::move(patterns))))
+    ConversionConfig config;
+    config.allowPatternRollback = allowPatternRollback;
+    if (failed(applyPartialConversion(op, target, std::move(patterns), config)))
       return failure();
     return success();
   }
@@ -206,9 +214,11 @@ class ConvertToLLVMPass
     std::shared_ptr<ConvertToLLVMPassInterface> impl;
     // Choose the pass implementation.
     if (useDynamic)
-      impl = std::make_shared<DynamicConvertToLLVM>(context, filterDialects);
+      impl = std::make_shared<DynamicConvertToLLVM>(context, filterDialects,
+                                                    allowPatternRollback);
     else
-      impl = std::make_shared<StaticConvertToLLVM>(context, filterDialects);
+      impl = std::make_shared<StaticConvertToLLVM>(context, filterDialects,
+                                                   allowPatternRollback);
     if (failed(impl->initialize()))
       return failure();
     this->impl = impl;
@@ -228,8 +238,10 @@ class ConvertToLLVMPass
 //===----------------------------------------------------------------------===//
 
 ConvertToLLVMPassInterface::ConvertToLLVMPassInterface(
-    MLIRContext *context, ArrayRef<std::string> filterDialects)
-    : context(context), filterDialects(filterDialects) {}
+    MLIRContext *context, ArrayRef<std::string> filterDialects,
+    bool allowPatternRollback)
+    : context(context), filterDialects(filterDialects),
+      allowPatternRollback(allowPatternRollback) {}
 
 void ConvertToLLVMPassInterface::getDependentDialects(
     DialectRegistry &registry) {
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
index 134aef3a6c719..0e88d31dae8e8 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
@@ -730,9 +730,9 @@ class SparseTensorCompressConverter : public OpConversionPattern<CompressOp> {
                    {tensor, lvlCoords, values, filled, added, count},
                    EmitCInterface::On);
     Operation *parent = getTop(op);
+    rewriter.setInsertionPointAfter(parent);
     rewriter.replaceOp(op, adaptor.getTensor());
     // Deallocate the buffers on exit of the loop nest.
-    rewriter.setInsertionPointAfter(parent);
     memref::DeallocOp::create(rewriter, loc, values);
     memref::DeallocOp::create(rewriter, loc, filled);
     memref::DeallocOp::create(rewriter, loc, added);
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index 0c26b4ed46b31..2ae4718bdc867 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -182,15 +182,24 @@ struct ConversionValueMapping {
 /// conversions.)
 static const StringRef kPureTypeConversionMarker = "__pure_type_conversion__";
 
+/// Return the operation that defines all values in the vector. Return nullptr
+/// if the values are not defined by the same operation.
+static Operation *getCommonDefiningOp(const ValueVector &values) {
+  assert(!values.empty() && "expected non-empty value vector");
+  Operation *op = values.front().getDefiningOp();
+  for (Value v : llvm::drop_begin(values)) {
+    if (v.getDefiningOp() != op)
+      return nullptr;
+  }
+  return op;
+}
+
 /// A vector of values is a pure type conversion if all values are defined by
 /// the same operation and the operation has the `kPureTypeConversionMarker`
 /// attribute.
 static bool isPureTypeConversion(const ValueVector &values) {
   assert(!values.empty() && "expected non-empty value vector");
-  Operation *op = values.front().getDefiningOp();
-  for (Value v : llvm::drop_begin(values))
-    if (v.getDefiningOp() != op)
-      return false;
+  Operation *op = getCommonDefiningOp(values);
   return op && op->hasAttr(kPureTypeConversionMarker);
 }
 
@@ -841,7 +850,7 @@ namespace detail {
 struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
   explicit ConversionPatternRewriterImpl(MLIRContext *ctx,
                                          const ConversionConfig &config)
-      : context(ctx), config(config) {}
+      : context(ctx), config(config), notifyingRewriter(ctx, config.listener) {}
 
   //===--------------------------------------------------------------------===//
   // State Management
@@ -863,6 +872,7 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
   /// failure.
   template <typename RewriteTy, typename... Args>
   void appendRewrite(Args &&...args) {
+    assert(config.allowPatternRollback && "appending rewrites is not allowed");
     rewrites.push_back(
         std::make_unique<RewriteTy>(*this, std::forward<Args>(args)...));
   }
@@ -889,15 +899,8 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
   bool wasOpReplaced(Operation *op) const;
 
   /// Lookup the most recently mapped values with the desired types in the
-  /// mapping.
-  ///
-  /// Special cases:
-  /// - If the desired type range is empty, simply return the most recently
-  ///   mapped values.
-  /// - If there is no mapping to the desired types, also return the most
-  ///   recently mapped values.
-  /// - If there is no mapping for the given values at all, return the given
-  ///   value.
+  /// mapping, taking into account only replacements. Perform a best-effort
+  /// search for existing materializations with the desired types.
   ///
   /// If `skipPureTypeConversions` is "true", materializations that are pure
   /// type conversions are not considered.
@@ -1066,6 +1069,9 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
   ConversionValueMapping mapping;
 
   /// Ordered list of block operations (creations, splits, motions).
+  /// This vector is maintained only if `allowPatternRollback` is set to
+  /// "true". Otherwise, all IR rewrites are materialized immediately and no
+  /// bookkeeping is needed.
   SmallVector<std::unique_ptr<IRRewrite>> rewrites;
 
   /// A set of operations that should no longer be considered for legalization.
@@ -1089,6 +1095,10 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
   /// by the current pattern.
   SetVector<Block *> patternInsertedBlocks;
 
+  /// A list of unresolved materializations that were created by the current
+  /// pattern.
+  DenseSet<UnrealizedConversionCastOp> patternMaterializations;
+
   /// A mapping for looking up metadata of unresolved materializations.
   DenseMap<UnrealizedConversionCastOp, UnresolvedMaterializationInfo>
       unresolvedMaterializations;
@@ -1104,6 +1114,23 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
   /// Dialect conversion configuration.
   const ConversionConfig &config;
 
+  /// A set of erased operations. This set is utilized only if
+  /// `allowPatternRollback` is set to "false". Conceptually, this set is
+  /// simialar to `replacedOps` (which is maintained when the flag is set to
+  /// "true"). However, erasing from a DenseSet is more efficient than erasing
+  /// from a SetVector.
+  DenseSet<Operation *> erasedOps;
+
+  /// A set of erased blocks. This set is utilized only if
+  /// `allowPatternRollback` is set to "false".
+  DenseSet<Block *> erasedBlocks;
+
+  /// A rewriter that notifies the listener (if any) about all IR
+  /// modifications. This rewriter is utilized only if `allowPatternRollback`
+  /// is set to "false". If the flag is set to "true", the listener is notified
+  /// with a separate mechanism (e.g., in `IRRewrite::commit`).
+  IRRewriter notifyingRewriter;
+
 #ifndef NDEBUG
   /// A set of operations that have pending updates. This tracking isn't
   /// strictly necessary, and is thus only active during debug builds for extra
@@ -1140,11 +1167,8 @@ void BlockTypeConversionRewrite::rollback() {
   getNewBlock()->replaceAllUsesWith(getOrigBlock());
 }
 
-void ReplaceBlockArgRewrite::commit(RewriterBase &rewriter) {
-  Value repl = rewriterImpl.findOrBuildReplacementValue(arg, converter);
-  if (!repl)
-    return;
-
+static void performReplaceBlockArg(RewriterBase &rewriter, BlockArgument arg,
+                                   Value repl) {
   if (isa<BlockArgument>(repl)) {
     rewriter.replaceAllUsesWith(arg, repl);
     return;
@@ -1161,6 +1185,13 @@ void ReplaceBlockArgRewrite::commit(RewriterBase &rewriter) {
   });
 }
 
+void ReplaceBlockArgRewrite::commit(RewriterBase &rewriter) {
+  Value repl = rewriterImpl.findOrBuildReplacementValue(arg, converter);
+  if (!repl)
+    return;
+  performReplaceBlockArg(rewriter, arg, repl);
+}
+
 void ReplaceBlockArgRewrite::rollback() { rewriterImpl.mapping.erase({arg}); }
 
 void ReplaceOperationRewrite::commit(RewriterBase &rewriter) {
@@ -1246,6 +1277,30 @@ void ConversionPatternRewriterImpl::applyRewrites() {
 
 ValueVector ConversionPatternRewriterImpl::lookupOrDefault(
     Value from, TypeRange desiredTypes, bool skipPureTypeConversions) const {
+  // Helper function that looks up a single value.
+  auto lookup = [&](const ValueVector &values) -> ValueVector {
+    assert(!values.empty() && "expected non-empty value vector");
+
+    // If the pattern rollback is enabled, use the mapping to look up the
+    // values.
+    if (config.allowPatternRollback)
+      return mapping.lookup(values);
+
+    // Otherwise, look up values by examining the IR. All replacements have
+    // already been materialized in IR.
+    Operation *op = getCommonDefiningOp(values);
+    if (!op)
+      return {};
+    auto castOp = dyn_cast<UnrealizedConversionCastOp>(op);
+    if (!castOp)
+      return {};
+    if (!this->unresolvedMaterializations.contains(castOp))
+      return {};
+    if (castOp.getOutputs() != values)
+      return {};
+    return castOp.getInputs();
+  };
+
   // Helper function that looks up each value in `values` individually and then
   // composes the results. If that fails, it tries to look up the entire vector
   // at once.
@@ -1253,7 +1308,7 @@ ValueVector ConversionPatternRewriterImpl::lookupOrDefault(
     // If possible, replace each value with (one or multiple) mapped values.
     ValueVector next;
     for (Value v : values) {
-      ValueVector r = mapping.lookup({v});
+      ValueVector r = lookup({v});
       if (!r.empty()) {
         llvm::append_range(next, r);
       } else {
@@ -1273,7 +1328,7 @@ ValueVector ConversionPatternRewriterImpl::lookupOrDefault(
     // be stored (and looked up) in the mapping. But for performance reasons,
     // we choose to reuse existing IR (when possible) instead of creating it
     // multiple times.
-    ValueVector r = mapping.lookup(values);
+    ValueVector r = lookup(values);
     if (r.empty()) {
       // No mapping found: The lookup stops here.
       return {};
@@ -1347,15 +1402,8 @@ void ConversionPatternRewriterImpl::resetState(RewriterState state,
 void ConversionPatternRewriterImpl::undoRewrites(unsigned numRewritesToKeep,
                                                  StringRef patternName) {
   for (auto &rewrite :
-       llvm::reverse(llvm::drop_begin(rewrites, numRewritesToKeep))) {
-    if (!config.allowPatternRollback &&
-        !isa<UnresolvedMaterializationRewrite>(rewrite)) {
-      // Unresolved materializations can always be rolled back (erased).
-      llvm::report_fatal_error("pattern '" + patternName +
-                               "' rollback of IR modifications requested");
-    }
+       llvm::reverse(llvm::drop_begin(rewrites, numRewritesToKeep)))
     rewrite->rollback();
-  }
   rewrites.resize(numRewritesToKeep);
 }
 
@@ -1419,12 +1467,12 @@ LogicalResult ConversionPatternRewriterImpl::remapValues(
 
 bool ConversionPatternRewriterImpl::isOpIgnored(Operation *op) const {
   // Check to see if this operation is ignored or was replaced.
-  return replacedOps.count(op) || ignoredOps.count(op);
+  return wasOpReplaced(op) || ignoredOps.count(op);
 }
 
 bool ConversionPatternRewriterImpl::wasOpReplaced(Operation *op) const {
   // Check to see if this operation was replaced.
-  return replacedOps.count(op);
+  return replacedOps.count(op) || erasedOps.count(op);
 }
 
 //===----------------------------------------------------------------------===//
@@ -1508,7 +1556,8 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
   // a bit more efficient, so we try to do that when possible.
   bool fastPath = !config.listener;
   if (fastPath) {
-    appendRewrite<InlineBlockRewrite>(newBlock, block, newBlock->end());
+    if (config.allowPatternRollback)
+      appendRewrite<InlineBlockRewrite>(newBlock, block, newBlock->end());
     newBlock->getOperations().splice(newBlock->end(), block->getOperations());
   } else {
     while (!block->empty())
@@ -1556,7 +1605,8 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
     replaceUsesOfBlockArgument(origArg, replArgs, converter);
   }
 
-  appendRewrite<BlockTypeConversionRewrite>(/*origBlock=*/block, newBlock);
+  if (config.allowPatternRollback)
+    appendRewrite<BlockTypeConversionRewrite>(/*origBlock=*/block, newBlock);
 
   // Erase the old block. (It is just unlinked for now and will be erased during
   // cleanup.)
@@ -1585,23 +1635,32 @@ ValueRange ConversionPatternRewriterImpl::buildUnresolvedMaterialization(
   // tracking the materialization like we do for other operations.
   OpBuilder builder(outputTypes.front().getContext());
   builder.setInsertionPoint(ip.getBlock(), ip.getPoint());
-  auto convertOp =
+  UnrealizedConversionCastOp convertOp =
       UnrealizedConversionCastOp::create(builder, loc, outputTypes, inputs);
   if (isPureTypeConversion)
     convertOp->setAttr(kPureTypeConversionMarker, builder.getUnitAttr());
-  if (!valuesToMap.empty())
-    mapping.map(valuesToMap, convertOp.getResults());
+
+  // Register the materialization.
   if (castOp)
     *castOp = convertOp;
   unresolvedMaterializations[convertOp] =
       UnresolvedMaterializationInfo(converter, kind, originalType);
-  appendRewrite<UnresolvedMaterializationRewrite>(convertOp,
-                                                  std::move(valuesToMap));
+  if (config.allowPatternRollback) {
+    if (!valuesToMap.empty())
+      mapping.map(valuesToMap, convertOp.getResults());
+    appendRewrite<UnresolvedMaterializationRewrite>(convertOp,
+                                                    std::move(valuesToMap));
+  } else {
+    patternMaterializations.insert(convertOp);
+  }
   return convertOp.getResults();
 }
 
 Value ConversionPatternRewriterImpl::findOrBuildReplacementValue(
     Value value, const TypeConverter *converter) {
+  assert(config.allowPatternRollback &&
+         "this code path is valid only in rollback mode");
+
   // Try to find a replacement value with the same type in the conversion value
   // mapping. This includes cached materializations. We try to reuse those
   // instead of generating duplicate IR.
@@ -1663,26 +1722,119 @@ void ConversionPatternRewriterImpl::notifyOperationInserted(
       logger.getOStream() << " (was detached)";
     logger.getOStream() << "\n";
   });
-  assert(!wasOpReplaced(op->getParentOp()) &&
+
+  // In rollback mode, it is easier to misuse the API, so perform extra error
+  // checking.
+  assert(!(config.allowPatternRollback && wasOpReplaced(op->getParentOp())) &&
          "attempting to insert into a block within a replaced/erased op");
 
+  // In "no rollback" mode, the listener is always notified immediately.
+  if (!config.allowPatternRollback && config.listener)
+    config.listener->notifyOperationInserted(op, previous);
+
   if (wasDetached) {
-    // If the op was detached, it is most likely a newly created op.
-    // TODO: If the same op is inserted multiple times from a detached state,
-    // the rollback mechanism may erase the same op multiple times. This is a
-    // bug in the rollback-based dialect conversion driver.
-    appendRewrite<CreateOperationRewrite>(op);
+    // If the op was detached, it is most ...
[truncated]

Copy link
Contributor

@j2kun j2kun left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Really fantastic to see this coming together!

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: is this change related to this PR? Maybe an attempt to migrate this conversion pass that failed and this change was not reverted?

It seems harmless though.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is required to fix a use-after-free ASAN crash. getTop(op) returns op. The op is erased by replaceOp(op, adaptor.getTensor()). The setInsertionPointAfter(parent) accessed the deallocated op.

I'm using the SparseTensor integration tests for benchmarking in the RFC, so I wanted to make sure that the SparseTensor test suite is working with this PR.

// CHECK-NEXT: %[[CAST0:.*]] = builtin.unrealized_conversion_cast %[[CPLX]] : complex<f32> to !llvm.struct<(f32, f32)>
// CHECK-NEXT: %[[REAL:.*]] = llvm.extractvalue %[[CAST0]][0] : !llvm.struct<(f32, f32)>
// CHECK-NEXT: %[[IMAG:.*]] = llvm.extractvalue %[[CAST0]][1] : !llvm.struct<(f32, f32)>
// CHECK: builtin.unrealized_conversion_cast %[[CPLX]] : complex<f32> to !llvm.struct<(f32, f32)>
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why are the captures omitted here and in the test for arith-to-llvm?

Copy link
Member Author

@matthias-springer matthias-springer Aug 10, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The new driver inserts two identical unrealized_conversion_casts here. The old driver inserts only one. With this change, the same CHECK can be used for both the old and the new driver.

Copy link
Member

@zero9178 zero9178 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM!

@matthias-springer matthias-springer force-pushed the users/matthias-springer/prototype_one_shot branch from 09d8ee9 to 2422ce2 Compare August 13, 2025 13:25
// RUN: mlir-opt -allow-unregistered-dialect -split-input-file -test-legalize-patterns -verify-diagnostics -profile-actions-to=- %s | FileCheck %s
// RUN: mlir-opt -allow-unregistered-dialect -split-input-file -test-legalize-patterns="allow-pattern-rollback=1" -verify-diagnostics %s | FileCheck %s
// RUN: mlir-opt -allow-unregistered-dialect -split-input-file -test-legalize-patterns="allow-pattern-rollback=1" -verify-diagnostics -profile-actions-to=- %s | FileCheck %s --check-prefix=CHECK-PROFILER
// RUN: mlir-opt -allow-unregistered-dialect -split-input-file -test-legalize-patterns="allow-pattern-rollback=0" -verify-diagnostics %s | FileCheck %s
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we need all these -allow-unregistered-dialect?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There are a few test cases that use ops such as "work"(%arg) : (i16) -> (). I can clean that up in a separate PR.

Copy link
Collaborator

@joker-eph joker-eph left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM but the changes in mlir/lib/Transforms/Utils/DialectConversion.cpp are tedious to review. I'll trust the correctness from the testing!

@matthias-springer matthias-springer force-pushed the users/matthias-springer/prototype_one_shot branch from 2422ce2 to 5ade1f5 Compare August 13, 2025 14:44
@matthias-springer matthias-springer merged commit 7e7c9d9 into main Aug 13, 2025
9 checks passed
@matthias-springer matthias-springer deleted the users/matthias-springer/prototype_one_shot branch August 13, 2025 15:40
matthias-springer added a commit that referenced this pull request Aug 13, 2025
matthias-springer added a commit that referenced this pull request Aug 13, 2025
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

6 participants