-
Notifications
You must be signed in to change notification settings - Fork 14.7k
[mlir][Transforms] Dialect Conversion Driver without Rollback #151865
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[mlir][Transforms] Dialect Conversion Driver without Rollback #151865
Conversation
✅ With the latest revision this PR passed the C/C++ code formatter. |
4932ff9
to
e789428
Compare
9818ca9
to
45f831d
Compare
e789428
to
a94080b
Compare
fb442b3
to
371998e
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Pull Request Overview
This commit implements an experimental One-Shot Dialect Conversion driver that bypasses the rollback infrastructure when allowPatternRollback
is set to false. This improves performance by materializing IR changes immediately instead of maintaining rollback state, though it requires patterns to avoid API misuse since rollback is no longer possible.
- Implements the new "no rollback" conversion driver that materializes changes immediately
- Adds extensive test coverage for the new driver across integration and conversion tests
- Updates the ConvertToLLVM pass to support the new
allow-pattern-rollback
flag
Reviewed Changes
Copilot reviewed 30 out of 30 changed files in this pull request and generated 3 comments.
Show a summary per file
File | Description |
---|---|
mlir/lib/Transforms/Utils/DialectConversion.cpp | Core implementation of the One-Shot driver with immediate IR materialization |
mlir/test/lib/Dialect/Test/TestPatterns.cpp | Adds test flag and fixes pattern insertion point issues |
mlir/lib/Conversion/ConvertToLLVM/ConvertToLLVMPass.cpp | Adds support for the new flag in the convert-to-llvm pass |
mlir/include/mlir/Transforms/DialectConversion.h | Updates documentation for the experimental flag |
Various test files | Adds test coverage with both rollback and no-rollback modes |
Comments suppressed due to low confidence (1)
mlir/lib/Transforms/Utils/DialectConversion.cpp:1778
- This assertion in getReplacementValues could interfere with performance profiling data collection. The assertion affects the control flow in a function that's likely called frequently during conversion, potentially impacting performance measurements.
// thin air".
371998e
to
09d8ee9
Compare
@llvm/pr-subscribers-mlir-memref @llvm/pr-subscribers-mlir Author: Matthias Springer (matthias-springer) ChangesThis commit improves the A few selected test cases now run with both the old and the new driver. RFC: https://discourse.llvm.org/t/rfc-a-new-one-shot-dialect-conversion-driver/79083 Patch is 62.06 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/151865.diff 30 Files Affected:
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index 6e1baaf23fcf7..e6a80435816a3 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -52,6 +52,8 @@ def ConvertToLLVMPass : Pass<"convert-to-llvm"> {
"Test conversion patterns of only the specified dialects">,
Option<"useDynamic", "dynamic", "bool", "false",
"Use op conversion attributes to configure the conversion">,
+ Option<"allowPatternRollback", "allow-pattern-rollback", "bool", "true",
+ "Experimental performance flag to disallow pattern rollback">
];
}
diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h
index 4e651a0489899..00903006bb560 100644
--- a/mlir/include/mlir/Transforms/DialectConversion.h
+++ b/mlir/include/mlir/Transforms/DialectConversion.h
@@ -1231,16 +1231,16 @@ struct ConversionConfig {
/// 2. Pattern produces IR (in-place modification or new IR) that is illegal
/// and cannot be legalized by subsequent foldings / pattern applications.
///
- /// If set to "false", the conversion driver will produce an LLVM fatal error
- /// instead of rolling back IR modifications. Moreover, in case of a failed
- /// conversion, the original IR is not restored. The resulting IR may be a
- /// mix of original and rewritten IR. (Same as a failed greedy pattern
- /// rewrite.)
+ /// Experimental: If set to "false", the conversion driver will produce an
+ /// LLVM fatal error instead of rolling back IR modifications. Moreover, in
+ /// case of a failed conversion, the original IR is not restored. The
+ /// resulting IR may be a mix of original and rewritten IR. (Same as a failed
+ /// greedy pattern rewrite.) Use MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
+ /// with ASAN to detect invalid pattern API usage.
///
- /// Note: This flag was added in preparation of the One-Shot Dialect
- /// Conversion refactoring, which will remove the ability to roll back IR
- /// modifications from the conversion driver. Use this flag to ensure that
- /// your patterns do not trigger any IR rollbacks. For details, see
+ /// When pattern rollback is disabled, the conversion driver has to maintain
+ /// less internal state. This is more efficient, but not supported by all
+ /// lowering patterns. For details, see
/// https://discourse.llvm.org/t/rfc-a-new-one-shot-dialect-conversion-driver/79083.
bool allowPatternRollback = true;
};
diff --git a/mlir/lib/Conversion/ConvertToLLVM/ConvertToLLVMPass.cpp b/mlir/lib/Conversion/ConvertToLLVM/ConvertToLLVMPass.cpp
index ed5d6d4a7fe40..cdb715064b0f7 100644
--- a/mlir/lib/Conversion/ConvertToLLVM/ConvertToLLVMPass.cpp
+++ b/mlir/lib/Conversion/ConvertToLLVM/ConvertToLLVMPass.cpp
@@ -31,7 +31,8 @@ namespace {
class ConvertToLLVMPassInterface {
public:
ConvertToLLVMPassInterface(MLIRContext *context,
- ArrayRef<std::string> filterDialects);
+ ArrayRef<std::string> filterDialects,
+ bool allowPatternRollback = true);
virtual ~ConvertToLLVMPassInterface() = default;
/// Get the dependent dialects used by `convert-to-llvm`.
@@ -60,6 +61,9 @@ class ConvertToLLVMPassInterface {
MLIRContext *context;
/// List of dialects names to use as filters.
ArrayRef<std::string> filterDialects;
+ /// An experimental flag to disallow pattern rollback. This is more efficient
+ /// but not supported by all lowering patterns.
+ bool allowPatternRollback;
};
/// This DialectExtension can be attached to the context, which will invoke the
@@ -128,7 +132,9 @@ struct StaticConvertToLLVM : public ConvertToLLVMPassInterface {
/// Apply the conversion driver.
LogicalResult transform(Operation *op, AnalysisManager manager) const final {
- if (failed(applyPartialConversion(op, *target, *patterns)))
+ ConversionConfig config;
+ config.allowPatternRollback = allowPatternRollback;
+ if (failed(applyPartialConversion(op, *target, *patterns, config)))
return failure();
return success();
}
@@ -179,7 +185,9 @@ struct DynamicConvertToLLVM : public ConvertToLLVMPassInterface {
patterns);
// Apply the conversion.
- if (failed(applyPartialConversion(op, target, std::move(patterns))))
+ ConversionConfig config;
+ config.allowPatternRollback = allowPatternRollback;
+ if (failed(applyPartialConversion(op, target, std::move(patterns), config)))
return failure();
return success();
}
@@ -206,9 +214,11 @@ class ConvertToLLVMPass
std::shared_ptr<ConvertToLLVMPassInterface> impl;
// Choose the pass implementation.
if (useDynamic)
- impl = std::make_shared<DynamicConvertToLLVM>(context, filterDialects);
+ impl = std::make_shared<DynamicConvertToLLVM>(context, filterDialects,
+ allowPatternRollback);
else
- impl = std::make_shared<StaticConvertToLLVM>(context, filterDialects);
+ impl = std::make_shared<StaticConvertToLLVM>(context, filterDialects,
+ allowPatternRollback);
if (failed(impl->initialize()))
return failure();
this->impl = impl;
@@ -228,8 +238,10 @@ class ConvertToLLVMPass
//===----------------------------------------------------------------------===//
ConvertToLLVMPassInterface::ConvertToLLVMPassInterface(
- MLIRContext *context, ArrayRef<std::string> filterDialects)
- : context(context), filterDialects(filterDialects) {}
+ MLIRContext *context, ArrayRef<std::string> filterDialects,
+ bool allowPatternRollback)
+ : context(context), filterDialects(filterDialects),
+ allowPatternRollback(allowPatternRollback) {}
void ConvertToLLVMPassInterface::getDependentDialects(
DialectRegistry ®istry) {
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
index 134aef3a6c719..0e88d31dae8e8 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
@@ -730,9 +730,9 @@ class SparseTensorCompressConverter : public OpConversionPattern<CompressOp> {
{tensor, lvlCoords, values, filled, added, count},
EmitCInterface::On);
Operation *parent = getTop(op);
+ rewriter.setInsertionPointAfter(parent);
rewriter.replaceOp(op, adaptor.getTensor());
// Deallocate the buffers on exit of the loop nest.
- rewriter.setInsertionPointAfter(parent);
memref::DeallocOp::create(rewriter, loc, values);
memref::DeallocOp::create(rewriter, loc, filled);
memref::DeallocOp::create(rewriter, loc, added);
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index 0c26b4ed46b31..2ae4718bdc867 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -182,15 +182,24 @@ struct ConversionValueMapping {
/// conversions.)
static const StringRef kPureTypeConversionMarker = "__pure_type_conversion__";
+/// Return the operation that defines all values in the vector. Return nullptr
+/// if the values are not defined by the same operation.
+static Operation *getCommonDefiningOp(const ValueVector &values) {
+ assert(!values.empty() && "expected non-empty value vector");
+ Operation *op = values.front().getDefiningOp();
+ for (Value v : llvm::drop_begin(values)) {
+ if (v.getDefiningOp() != op)
+ return nullptr;
+ }
+ return op;
+}
+
/// A vector of values is a pure type conversion if all values are defined by
/// the same operation and the operation has the `kPureTypeConversionMarker`
/// attribute.
static bool isPureTypeConversion(const ValueVector &values) {
assert(!values.empty() && "expected non-empty value vector");
- Operation *op = values.front().getDefiningOp();
- for (Value v : llvm::drop_begin(values))
- if (v.getDefiningOp() != op)
- return false;
+ Operation *op = getCommonDefiningOp(values);
return op && op->hasAttr(kPureTypeConversionMarker);
}
@@ -841,7 +850,7 @@ namespace detail {
struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
explicit ConversionPatternRewriterImpl(MLIRContext *ctx,
const ConversionConfig &config)
- : context(ctx), config(config) {}
+ : context(ctx), config(config), notifyingRewriter(ctx, config.listener) {}
//===--------------------------------------------------------------------===//
// State Management
@@ -863,6 +872,7 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
/// failure.
template <typename RewriteTy, typename... Args>
void appendRewrite(Args &&...args) {
+ assert(config.allowPatternRollback && "appending rewrites is not allowed");
rewrites.push_back(
std::make_unique<RewriteTy>(*this, std::forward<Args>(args)...));
}
@@ -889,15 +899,8 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
bool wasOpReplaced(Operation *op) const;
/// Lookup the most recently mapped values with the desired types in the
- /// mapping.
- ///
- /// Special cases:
- /// - If the desired type range is empty, simply return the most recently
- /// mapped values.
- /// - If there is no mapping to the desired types, also return the most
- /// recently mapped values.
- /// - If there is no mapping for the given values at all, return the given
- /// value.
+ /// mapping, taking into account only replacements. Perform a best-effort
+ /// search for existing materializations with the desired types.
///
/// If `skipPureTypeConversions` is "true", materializations that are pure
/// type conversions are not considered.
@@ -1066,6 +1069,9 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
ConversionValueMapping mapping;
/// Ordered list of block operations (creations, splits, motions).
+ /// This vector is maintained only if `allowPatternRollback` is set to
+ /// "true". Otherwise, all IR rewrites are materialized immediately and no
+ /// bookkeeping is needed.
SmallVector<std::unique_ptr<IRRewrite>> rewrites;
/// A set of operations that should no longer be considered for legalization.
@@ -1089,6 +1095,10 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
/// by the current pattern.
SetVector<Block *> patternInsertedBlocks;
+ /// A list of unresolved materializations that were created by the current
+ /// pattern.
+ DenseSet<UnrealizedConversionCastOp> patternMaterializations;
+
/// A mapping for looking up metadata of unresolved materializations.
DenseMap<UnrealizedConversionCastOp, UnresolvedMaterializationInfo>
unresolvedMaterializations;
@@ -1104,6 +1114,23 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
/// Dialect conversion configuration.
const ConversionConfig &config;
+ /// A set of erased operations. This set is utilized only if
+ /// `allowPatternRollback` is set to "false". Conceptually, this set is
+ /// simialar to `replacedOps` (which is maintained when the flag is set to
+ /// "true"). However, erasing from a DenseSet is more efficient than erasing
+ /// from a SetVector.
+ DenseSet<Operation *> erasedOps;
+
+ /// A set of erased blocks. This set is utilized only if
+ /// `allowPatternRollback` is set to "false".
+ DenseSet<Block *> erasedBlocks;
+
+ /// A rewriter that notifies the listener (if any) about all IR
+ /// modifications. This rewriter is utilized only if `allowPatternRollback`
+ /// is set to "false". If the flag is set to "true", the listener is notified
+ /// with a separate mechanism (e.g., in `IRRewrite::commit`).
+ IRRewriter notifyingRewriter;
+
#ifndef NDEBUG
/// A set of operations that have pending updates. This tracking isn't
/// strictly necessary, and is thus only active during debug builds for extra
@@ -1140,11 +1167,8 @@ void BlockTypeConversionRewrite::rollback() {
getNewBlock()->replaceAllUsesWith(getOrigBlock());
}
-void ReplaceBlockArgRewrite::commit(RewriterBase &rewriter) {
- Value repl = rewriterImpl.findOrBuildReplacementValue(arg, converter);
- if (!repl)
- return;
-
+static void performReplaceBlockArg(RewriterBase &rewriter, BlockArgument arg,
+ Value repl) {
if (isa<BlockArgument>(repl)) {
rewriter.replaceAllUsesWith(arg, repl);
return;
@@ -1161,6 +1185,13 @@ void ReplaceBlockArgRewrite::commit(RewriterBase &rewriter) {
});
}
+void ReplaceBlockArgRewrite::commit(RewriterBase &rewriter) {
+ Value repl = rewriterImpl.findOrBuildReplacementValue(arg, converter);
+ if (!repl)
+ return;
+ performReplaceBlockArg(rewriter, arg, repl);
+}
+
void ReplaceBlockArgRewrite::rollback() { rewriterImpl.mapping.erase({arg}); }
void ReplaceOperationRewrite::commit(RewriterBase &rewriter) {
@@ -1246,6 +1277,30 @@ void ConversionPatternRewriterImpl::applyRewrites() {
ValueVector ConversionPatternRewriterImpl::lookupOrDefault(
Value from, TypeRange desiredTypes, bool skipPureTypeConversions) const {
+ // Helper function that looks up a single value.
+ auto lookup = [&](const ValueVector &values) -> ValueVector {
+ assert(!values.empty() && "expected non-empty value vector");
+
+ // If the pattern rollback is enabled, use the mapping to look up the
+ // values.
+ if (config.allowPatternRollback)
+ return mapping.lookup(values);
+
+ // Otherwise, look up values by examining the IR. All replacements have
+ // already been materialized in IR.
+ Operation *op = getCommonDefiningOp(values);
+ if (!op)
+ return {};
+ auto castOp = dyn_cast<UnrealizedConversionCastOp>(op);
+ if (!castOp)
+ return {};
+ if (!this->unresolvedMaterializations.contains(castOp))
+ return {};
+ if (castOp.getOutputs() != values)
+ return {};
+ return castOp.getInputs();
+ };
+
// Helper function that looks up each value in `values` individually and then
// composes the results. If that fails, it tries to look up the entire vector
// at once.
@@ -1253,7 +1308,7 @@ ValueVector ConversionPatternRewriterImpl::lookupOrDefault(
// If possible, replace each value with (one or multiple) mapped values.
ValueVector next;
for (Value v : values) {
- ValueVector r = mapping.lookup({v});
+ ValueVector r = lookup({v});
if (!r.empty()) {
llvm::append_range(next, r);
} else {
@@ -1273,7 +1328,7 @@ ValueVector ConversionPatternRewriterImpl::lookupOrDefault(
// be stored (and looked up) in the mapping. But for performance reasons,
// we choose to reuse existing IR (when possible) instead of creating it
// multiple times.
- ValueVector r = mapping.lookup(values);
+ ValueVector r = lookup(values);
if (r.empty()) {
// No mapping found: The lookup stops here.
return {};
@@ -1347,15 +1402,8 @@ void ConversionPatternRewriterImpl::resetState(RewriterState state,
void ConversionPatternRewriterImpl::undoRewrites(unsigned numRewritesToKeep,
StringRef patternName) {
for (auto &rewrite :
- llvm::reverse(llvm::drop_begin(rewrites, numRewritesToKeep))) {
- if (!config.allowPatternRollback &&
- !isa<UnresolvedMaterializationRewrite>(rewrite)) {
- // Unresolved materializations can always be rolled back (erased).
- llvm::report_fatal_error("pattern '" + patternName +
- "' rollback of IR modifications requested");
- }
+ llvm::reverse(llvm::drop_begin(rewrites, numRewritesToKeep)))
rewrite->rollback();
- }
rewrites.resize(numRewritesToKeep);
}
@@ -1419,12 +1467,12 @@ LogicalResult ConversionPatternRewriterImpl::remapValues(
bool ConversionPatternRewriterImpl::isOpIgnored(Operation *op) const {
// Check to see if this operation is ignored or was replaced.
- return replacedOps.count(op) || ignoredOps.count(op);
+ return wasOpReplaced(op) || ignoredOps.count(op);
}
bool ConversionPatternRewriterImpl::wasOpReplaced(Operation *op) const {
// Check to see if this operation was replaced.
- return replacedOps.count(op);
+ return replacedOps.count(op) || erasedOps.count(op);
}
//===----------------------------------------------------------------------===//
@@ -1508,7 +1556,8 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
// a bit more efficient, so we try to do that when possible.
bool fastPath = !config.listener;
if (fastPath) {
- appendRewrite<InlineBlockRewrite>(newBlock, block, newBlock->end());
+ if (config.allowPatternRollback)
+ appendRewrite<InlineBlockRewrite>(newBlock, block, newBlock->end());
newBlock->getOperations().splice(newBlock->end(), block->getOperations());
} else {
while (!block->empty())
@@ -1556,7 +1605,8 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
replaceUsesOfBlockArgument(origArg, replArgs, converter);
}
- appendRewrite<BlockTypeConversionRewrite>(/*origBlock=*/block, newBlock);
+ if (config.allowPatternRollback)
+ appendRewrite<BlockTypeConversionRewrite>(/*origBlock=*/block, newBlock);
// Erase the old block. (It is just unlinked for now and will be erased during
// cleanup.)
@@ -1585,23 +1635,32 @@ ValueRange ConversionPatternRewriterImpl::buildUnresolvedMaterialization(
// tracking the materialization like we do for other operations.
OpBuilder builder(outputTypes.front().getContext());
builder.setInsertionPoint(ip.getBlock(), ip.getPoint());
- auto convertOp =
+ UnrealizedConversionCastOp convertOp =
UnrealizedConversionCastOp::create(builder, loc, outputTypes, inputs);
if (isPureTypeConversion)
convertOp->setAttr(kPureTypeConversionMarker, builder.getUnitAttr());
- if (!valuesToMap.empty())
- mapping.map(valuesToMap, convertOp.getResults());
+
+ // Register the materialization.
if (castOp)
*castOp = convertOp;
unresolvedMaterializations[convertOp] =
UnresolvedMaterializationInfo(converter, kind, originalType);
- appendRewrite<UnresolvedMaterializationRewrite>(convertOp,
- std::move(valuesToMap));
+ if (config.allowPatternRollback) {
+ if (!valuesToMap.empty())
+ mapping.map(valuesToMap, convertOp.getResults());
+ appendRewrite<UnresolvedMaterializationRewrite>(convertOp,
+ std::move(valuesToMap));
+ } else {
+ patternMaterializations.insert(convertOp);
+ }
return convertOp.getResults();
}
Value ConversionPatternRewriterImpl::findOrBuildReplacementValue(
Value value, const TypeConverter *converter) {
+ assert(config.allowPatternRollback &&
+ "this code path is valid only in rollback mode");
+
// Try to find a replacement value with the same type in the conversion value
// mapping. This includes cached materializations. We try to reuse those
// instead of generating duplicate IR.
@@ -1663,26 +1722,119 @@ void ConversionPatternRewriterImpl::notifyOperationInserted(
logger.getOStream() << " (was detached)";
logger.getOStream() << "\n";
});
- assert(!wasOpReplaced(op->getParentOp()) &&
+
+ // In rollback mode, it is easier to misuse the API, so perform extra error
+ // checking.
+ assert(!(config.allowPatternRollback && wasOpReplaced(op->getParentOp())) &&
"attempting to insert into a block within a replaced/erased op");
+ // In "no rollback" mode, the listener is always notified immediately.
+ if (!config.allowPatternRollback && config.listener)
+ config.listener->notifyOperationInserted(op, previous);
+
if (wasDetached) {
- // If the op was detached, it is most likely a newly created op.
- // TODO: If the same op is inserted multiple times from a detached state,
- // the rollback mechanism may erase the same op multiple times. This is a
- // bug in the rollback-based dialect conversion driver.
- appendRewrite<CreateOperationRewrite>(op);
+ // If the op was detached, it is most ...
[truncated]
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Really fantastic to see this coming together!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nit: is this change related to this PR? Maybe an attempt to migrate this conversion pass that failed and this change was not reverted?
It seems harmless though.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is required to fix a use-after-free ASAN crash. getTop(op)
returns op
. The op is erased by replaceOp(op, adaptor.getTensor())
. The setInsertionPointAfter(parent)
accessed the deallocated op.
I'm using the SparseTensor integration tests for benchmarking in the RFC, so I wanted to make sure that the SparseTensor test suite is working with this PR.
// CHECK-NEXT: %[[CAST0:.*]] = builtin.unrealized_conversion_cast %[[CPLX]] : complex<f32> to !llvm.struct<(f32, f32)> | ||
// CHECK-NEXT: %[[REAL:.*]] = llvm.extractvalue %[[CAST0]][0] : !llvm.struct<(f32, f32)> | ||
// CHECK-NEXT: %[[IMAG:.*]] = llvm.extractvalue %[[CAST0]][1] : !llvm.struct<(f32, f32)> | ||
// CHECK: builtin.unrealized_conversion_cast %[[CPLX]] : complex<f32> to !llvm.struct<(f32, f32)> |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why are the captures omitted here and in the test for arith-to-llvm
?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The new driver inserts two identical unrealized_conversion_casts
here. The old driver inserts only one. With this change, the same CHECK can be used for both the old and the new driver.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM!
09d8ee9
to
2422ce2
Compare
// RUN: mlir-opt -allow-unregistered-dialect -split-input-file -test-legalize-patterns -verify-diagnostics -profile-actions-to=- %s | FileCheck %s | ||
// RUN: mlir-opt -allow-unregistered-dialect -split-input-file -test-legalize-patterns="allow-pattern-rollback=1" -verify-diagnostics %s | FileCheck %s | ||
// RUN: mlir-opt -allow-unregistered-dialect -split-input-file -test-legalize-patterns="allow-pattern-rollback=1" -verify-diagnostics -profile-actions-to=- %s | FileCheck %s --check-prefix=CHECK-PROFILER | ||
// RUN: mlir-opt -allow-unregistered-dialect -split-input-file -test-legalize-patterns="allow-pattern-rollback=0" -verify-diagnostics %s | FileCheck %s |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Do we need all these -allow-unregistered-dialect
?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There are a few test cases that use ops such as "work"(%arg) : (i16) -> ()
. I can clean that up in a separate PR.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM but the changes in mlir/lib/Transforms/Utils/DialectConversion.cpp
are tedious to review. I'll trust the correctness from the testing!
2422ce2
to
5ade1f5
Compare
Fix build after #151865.
Fix build after #151865.
This commit improves the
allowPatternRollback
flag handling in the dialect conversion driver. Previously, this flag was used to merely detect cases that are incompatible with the new One-Shot Dialect Conversion driver. This commit implements the driver itself: when the flag is set to "false", all IR changes are materialized immediately, bypassing theIRRewrite
andConversionValueMapping
infrastructure.A few selected test cases now run with both the old and the new driver.
RFC: https://discourse.llvm.org/t/rfc-a-new-one-shot-dialect-conversion-driver/79083