Skip to content

[mlir][Transforms] Dialect Conversion Driver without Rollback #151865

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Aug 13, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions mlir/include/mlir/Conversion/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,8 @@ def ConvertToLLVMPass : Pass<"convert-to-llvm"> {
"Test conversion patterns of only the specified dialects">,
Option<"useDynamic", "dynamic", "bool", "false",
"Use op conversion attributes to configure the conversion">,
Option<"allowPatternRollback", "allow-pattern-rollback", "bool", "true",
"Experimental performance flag to disallow pattern rollback">
];
}

Expand Down
19 changes: 10 additions & 9 deletions mlir/include/mlir/Transforms/DialectConversion.h
Original file line number Diff line number Diff line change
Expand Up @@ -1241,16 +1241,17 @@ struct ConversionConfig {
/// 2. Pattern produces IR (in-place modification or new IR) that is illegal
/// and cannot be legalized by subsequent foldings / pattern applications.
///
/// If set to "false", the conversion driver will produce an LLVM fatal error
/// instead of rolling back IR modifications. Moreover, in case of a failed
/// conversion, the original IR is not restored. The resulting IR may be a
/// mix of original and rewritten IR. (Same as a failed greedy pattern
/// rewrite.)
/// Experimental: If set to "false", the conversion driver will produce an
/// LLVM fatal error instead of rolling back IR modifications. Moreover, in
/// case of a failed conversion, the original IR is not restored. The
/// resulting IR may be a mix of original and rewritten IR. (Same as a failed
/// greedy pattern rewrite.) Use the cmake build option
/// `-DMLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS=ON` (ideally together with
/// ASAN) to detect invalid pattern API usage.
///
/// Note: This flag was added in preparation of the One-Shot Dialect
/// Conversion refactoring, which will remove the ability to roll back IR
/// modifications from the conversion driver. Use this flag to ensure that
/// your patterns do not trigger any IR rollbacks. For details, see
/// When pattern rollback is disabled, the conversion driver has to maintain
/// less internal state. This is more efficient, but not supported by all
/// lowering patterns. For details, see
/// https://discourse.llvm.org/t/rfc-a-new-one-shot-dialect-conversion-driver/79083.
bool allowPatternRollback = true;

Expand Down
26 changes: 19 additions & 7 deletions mlir/lib/Conversion/ConvertToLLVM/ConvertToLLVMPass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,8 @@ namespace {
class ConvertToLLVMPassInterface {
public:
ConvertToLLVMPassInterface(MLIRContext *context,
ArrayRef<std::string> filterDialects);
ArrayRef<std::string> filterDialects,
bool allowPatternRollback = true);
virtual ~ConvertToLLVMPassInterface() = default;

/// Get the dependent dialects used by `convert-to-llvm`.
Expand Down Expand Up @@ -60,6 +61,9 @@ class ConvertToLLVMPassInterface {
MLIRContext *context;
/// List of dialects names to use as filters.
ArrayRef<std::string> filterDialects;
/// An experimental flag to disallow pattern rollback. This is more efficient
/// but not supported by all lowering patterns.
bool allowPatternRollback;
};

/// This DialectExtension can be attached to the context, which will invoke the
Expand Down Expand Up @@ -128,7 +132,9 @@ struct StaticConvertToLLVM : public ConvertToLLVMPassInterface {

/// Apply the conversion driver.
LogicalResult transform(Operation *op, AnalysisManager manager) const final {
if (failed(applyPartialConversion(op, *target, *patterns)))
ConversionConfig config;
config.allowPatternRollback = allowPatternRollback;
if (failed(applyPartialConversion(op, *target, *patterns, config)))
return failure();
return success();
}
Expand Down Expand Up @@ -179,7 +185,9 @@ struct DynamicConvertToLLVM : public ConvertToLLVMPassInterface {
patterns);

// Apply the conversion.
if (failed(applyPartialConversion(op, target, std::move(patterns))))
ConversionConfig config;
config.allowPatternRollback = allowPatternRollback;
if (failed(applyPartialConversion(op, target, std::move(patterns), config)))
return failure();
return success();
}
Expand All @@ -206,9 +214,11 @@ class ConvertToLLVMPass
std::shared_ptr<ConvertToLLVMPassInterface> impl;
// Choose the pass implementation.
if (useDynamic)
impl = std::make_shared<DynamicConvertToLLVM>(context, filterDialects);
impl = std::make_shared<DynamicConvertToLLVM>(context, filterDialects,
allowPatternRollback);
else
impl = std::make_shared<StaticConvertToLLVM>(context, filterDialects);
impl = std::make_shared<StaticConvertToLLVM>(context, filterDialects,
allowPatternRollback);
if (failed(impl->initialize()))
return failure();
this->impl = impl;
Expand All @@ -228,8 +238,10 @@ class ConvertToLLVMPass
//===----------------------------------------------------------------------===//

ConvertToLLVMPassInterface::ConvertToLLVMPassInterface(
MLIRContext *context, ArrayRef<std::string> filterDialects)
: context(context), filterDialects(filterDialects) {}
MLIRContext *context, ArrayRef<std::string> filterDialects,
bool allowPatternRollback)
: context(context), filterDialects(filterDialects),
allowPatternRollback(allowPatternRollback) {}

void ConvertToLLVMPassInterface::getDependentDialects(
DialectRegistry &registry) {
Expand Down
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: is this change related to this PR? Maybe an attempt to migrate this conversion pass that failed and this change was not reverted?

It seems harmless though.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is required to fix a use-after-free ASAN crash. getTop(op) returns op. The op is erased by replaceOp(op, adaptor.getTensor()). The setInsertionPointAfter(parent) accessed the deallocated op.

I'm using the SparseTensor integration tests for benchmarking in the RFC, so I wanted to make sure that the SparseTensor test suite is working with this PR.

Original file line number Diff line number Diff line change
Expand Up @@ -730,9 +730,9 @@ class SparseTensorCompressConverter : public OpConversionPattern<CompressOp> {
{tensor, lvlCoords, values, filled, added, count},
EmitCInterface::On);
Operation *parent = getTop(op);
rewriter.setInsertionPointAfter(parent);
rewriter.replaceOp(op, adaptor.getTensor());
// Deallocate the buffers on exit of the loop nest.
rewriter.setInsertionPointAfter(parent);
memref::DeallocOp::create(rewriter, loc, values);
memref::DeallocOp::create(rewriter, loc, filled);
memref::DeallocOp::create(rewriter, loc, added);
Expand Down
Loading