From c494a9dbaab7170cd7f260a134cfb324c5ce0c67 Mon Sep 17 00:00:00 2001 From: Matthias Springer Date: Sun, 20 Apr 2025 13:39:20 +0200 Subject: [PATCH 1/2] no rollback flag --- .../mlir/Transforms/DialectConversion.h | 20 +++++++ .../Transforms/Utils/DialectConversion.cpp | 57 ++++++++++++++----- 2 files changed, 63 insertions(+), 14 deletions(-) diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h index b6ab252456e70..b65b3ea971f91 100644 --- a/mlir/include/mlir/Transforms/DialectConversion.h +++ b/mlir/include/mlir/Transforms/DialectConversion.h @@ -1219,6 +1219,26 @@ struct ConversionConfig { /// materializations and instead inserts "builtin.unrealized_conversion_cast" /// ops to ensure that the resulting IR is valid. bool buildMaterializations = true; + + /// If set to "true", pattern rollback is allowed. The conversion driver + /// rolls back IR modifications in the following situations. + /// + /// 1. Pattern implementation returns "failure" after modifying IR. + /// 2. Pattern produces IR (in-place modification or new IR) that is illegal + /// and cannot be legalized by subsequent foldings / pattern applications. + /// + /// If set to "false", the conversion driver will produce an LLVM fatal error + /// instead of rolling back IR modifications. Moreover, in case of a failed + /// conversion, the original IR is not restored. The resulting IR may be a + /// mix of original and rewritten IR. (Same as a failed greedy pattern + /// rewrite.) + /// + /// Note: This flag was added in preparation of the One-Shot Dialect + /// Conversion refactoring, which will remove the ability to roll back IR + /// modifications from the conversion driver. Use this flag to ensure that + /// your patterns do not trigger any IR rollbacks. For details, see + /// https://discourse.llvm.org/t/rfc-a-new-one-shot-dialect-conversion-driver/79083. + bool allowPatternRollback = true; }; //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp index 4d250329c6f45..6deedd41bb9ea 100644 --- a/mlir/lib/Transforms/Utils/DialectConversion.cpp +++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp @@ -861,8 +861,10 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener { /// conversion process succeeds. void applyRewrites(); - /// Reset the state of the rewriter to a previously saved point. - void resetState(RewriterState state); + /// Reset the state of the rewriter to a previously saved point. Optionally, + /// the name of the pattern that triggered the rollback can specified for + /// debugging purposes. + void resetState(RewriterState state, StringRef patternName = ""); /// Append a rewrite. Rewrites are committed upon success and rolled back upon /// failure. @@ -873,8 +875,9 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener { } /// Undo the rewrites (motions, splits) one by one in reverse order until - /// "numRewritesToKeep" rewrites remains. - void undoRewrites(unsigned numRewritesToKeep = 0); + /// "numRewritesToKeep" rewrites remains. Optionally, the name of the pattern + /// that triggered the rollback can specified for debugging purposes. + void undoRewrites(unsigned numRewritesToKeep = 0, StringRef patternName = ""); /// Remap the given values to those with potentially different types. Returns /// success if the values could be remapped, failure otherwise. `valueDiagTag` @@ -1204,9 +1207,10 @@ RewriterState ConversionPatternRewriterImpl::getCurrentState() { return RewriterState(rewrites.size(), ignoredOps.size(), replacedOps.size()); } -void ConversionPatternRewriterImpl::resetState(RewriterState state) { +void ConversionPatternRewriterImpl::resetState(RewriterState state, + StringRef patternName) { // Undo any rewrites. - undoRewrites(state.numRewrites); + undoRewrites(state.numRewrites, patternName); // Pop all of the recorded ignored operations that are no longer valid. while (ignoredOps.size() != state.numIgnoredOperations) @@ -1216,10 +1220,19 @@ void ConversionPatternRewriterImpl::resetState(RewriterState state) { replacedOps.pop_back(); } -void ConversionPatternRewriterImpl::undoRewrites(unsigned numRewritesToKeep) { +void ConversionPatternRewriterImpl::undoRewrites(unsigned numRewritesToKeep, + StringRef patternName) { for (auto &rewrite : - llvm::reverse(llvm::drop_begin(rewrites, numRewritesToKeep))) + llvm::reverse(llvm::drop_begin(rewrites, numRewritesToKeep))) { + if (!config.allowPatternRollback && + !isa(rewrite)) { + // Unresolved materializations can always be rolled back (erased). + std::string errorMessage = "pattern '" + std::string(patternName) + + "' rollback of IR modifications requested"; + llvm_unreachable(errorMessage.c_str()); + } rewrite->rollback(); + } rewrites.resize(numRewritesToKeep); } @@ -2158,7 +2171,7 @@ OperationLegalizer::legalizeWithPattern(Operation *op, }); if (config.listener) config.listener->notifyPatternEnd(pattern, failure()); - rewriterImpl.resetState(curState); + rewriterImpl.resetState(curState, pattern.getDebugName()); appliedPatterns.erase(&pattern); }; @@ -2168,8 +2181,13 @@ OperationLegalizer::legalizeWithPattern(Operation *op, assert(rewriterImpl.pendingRootUpdates.empty() && "dangling root updates"); auto result = legalizePatternResult(op, pattern, rewriter, curState); appliedPatterns.erase(&pattern); - if (failed(result)) - rewriterImpl.resetState(curState); + if (failed(result)) { + if (!rewriterImpl.config.allowPatternRollback) + op->emitError("pattern '") + << pattern.getDebugName() + << "' produced IR that could not be legalized"; + rewriterImpl.resetState(curState, pattern.getDebugName()); + } if (config.listener) config.listener->notifyPatternEnd(pattern, result); return result; @@ -2674,9 +2692,20 @@ LogicalResult OperationConverter::convertOperations(ArrayRef ops) { ConversionPatternRewriter rewriter(ops.front()->getContext(), config); ConversionPatternRewriterImpl &rewriterImpl = rewriter.getImpl(); - for (auto *op : toConvert) - if (failed(convert(rewriter, op))) - return rewriterImpl.undoRewrites(), failure(); + for (auto *op : toConvert) { + if (failed(convert(rewriter, op))) { + // Dialect conversion failed. + if (rewriterImpl.config.allowPatternRollback) { + // Rollback is allowed: restore the original IR. + rewriterImpl.undoRewrites(); + } else { + // Rollback is not allowed: apply all modifications that have been + // performed so far. + rewriterImpl.applyRewrites(); + } + return failure(); + } + } // After a successful conversion, apply rewrites. rewriterImpl.applyRewrites(); From 65002df003bf4ed1df058a4a8ee932d19504f6af Mon Sep 17 00:00:00 2001 From: Matthias Springer Date: Tue, 22 Apr 2025 09:22:09 +0200 Subject: [PATCH 2/2] address comments --- mlir/lib/Transforms/Utils/DialectConversion.cpp | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp index 6deedd41bb9ea..eb3ca01a462a6 100644 --- a/mlir/lib/Transforms/Utils/DialectConversion.cpp +++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp @@ -1227,9 +1227,8 @@ void ConversionPatternRewriterImpl::undoRewrites(unsigned numRewritesToKeep, if (!config.allowPatternRollback && !isa(rewrite)) { // Unresolved materializations can always be rolled back (erased). - std::string errorMessage = "pattern '" + std::string(patternName) + - "' rollback of IR modifications requested"; - llvm_unreachable(errorMessage.c_str()); + llvm::report_fatal_error("pattern '" + patternName + + "' rollback of IR modifications requested"); } rewrite->rollback(); }