diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td index 6e1baaf23fcf7..e6a80435816a3 100644 --- a/mlir/include/mlir/Conversion/Passes.td +++ b/mlir/include/mlir/Conversion/Passes.td @@ -52,6 +52,8 @@ def ConvertToLLVMPass : Pass<"convert-to-llvm"> { "Test conversion patterns of only the specified dialects">, Option<"useDynamic", "dynamic", "bool", "false", "Use op conversion attributes to configure the conversion">, + Option<"allowPatternRollback", "allow-pattern-rollback", "bool", "true", + "Experimental performance flag to disallow pattern rollback"> ]; } diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h index 4e651a0489899..00903006bb560 100644 --- a/mlir/include/mlir/Transforms/DialectConversion.h +++ b/mlir/include/mlir/Transforms/DialectConversion.h @@ -1231,16 +1231,16 @@ struct ConversionConfig { /// 2. Pattern produces IR (in-place modification or new IR) that is illegal /// and cannot be legalized by subsequent foldings / pattern applications. /// - /// If set to "false", the conversion driver will produce an LLVM fatal error - /// instead of rolling back IR modifications. Moreover, in case of a failed - /// conversion, the original IR is not restored. The resulting IR may be a - /// mix of original and rewritten IR. (Same as a failed greedy pattern - /// rewrite.) + /// Experimental: If set to "false", the conversion driver will produce an + /// LLVM fatal error instead of rolling back IR modifications. Moreover, in + /// case of a failed conversion, the original IR is not restored. The + /// resulting IR may be a mix of original and rewritten IR. (Same as a failed + /// greedy pattern rewrite.) Use MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS + /// with ASAN to detect invalid pattern API usage. /// - /// Note: This flag was added in preparation of the One-Shot Dialect - /// Conversion refactoring, which will remove the ability to roll back IR - /// modifications from the conversion driver. Use this flag to ensure that - /// your patterns do not trigger any IR rollbacks. For details, see + /// When pattern rollback is disabled, the conversion driver has to maintain + /// less internal state. This is more efficient, but not supported by all + /// lowering patterns. For details, see /// https://discourse.llvm.org/t/rfc-a-new-one-shot-dialect-conversion-driver/79083. bool allowPatternRollback = true; }; diff --git a/mlir/lib/Conversion/ConvertToLLVM/ConvertToLLVMPass.cpp b/mlir/lib/Conversion/ConvertToLLVM/ConvertToLLVMPass.cpp index ed5d6d4a7fe40..cdb715064b0f7 100644 --- a/mlir/lib/Conversion/ConvertToLLVM/ConvertToLLVMPass.cpp +++ b/mlir/lib/Conversion/ConvertToLLVM/ConvertToLLVMPass.cpp @@ -31,7 +31,8 @@ namespace { class ConvertToLLVMPassInterface { public: ConvertToLLVMPassInterface(MLIRContext *context, - ArrayRef filterDialects); + ArrayRef filterDialects, + bool allowPatternRollback = true); virtual ~ConvertToLLVMPassInterface() = default; /// Get the dependent dialects used by `convert-to-llvm`. @@ -60,6 +61,9 @@ class ConvertToLLVMPassInterface { MLIRContext *context; /// List of dialects names to use as filters. ArrayRef filterDialects; + /// An experimental flag to disallow pattern rollback. This is more efficient + /// but not supported by all lowering patterns. + bool allowPatternRollback; }; /// This DialectExtension can be attached to the context, which will invoke the @@ -128,7 +132,9 @@ struct StaticConvertToLLVM : public ConvertToLLVMPassInterface { /// Apply the conversion driver. LogicalResult transform(Operation *op, AnalysisManager manager) const final { - if (failed(applyPartialConversion(op, *target, *patterns))) + ConversionConfig config; + config.allowPatternRollback = allowPatternRollback; + if (failed(applyPartialConversion(op, *target, *patterns, config))) return failure(); return success(); } @@ -179,7 +185,9 @@ struct DynamicConvertToLLVM : public ConvertToLLVMPassInterface { patterns); // Apply the conversion. - if (failed(applyPartialConversion(op, target, std::move(patterns)))) + ConversionConfig config; + config.allowPatternRollback = allowPatternRollback; + if (failed(applyPartialConversion(op, target, std::move(patterns), config))) return failure(); return success(); } @@ -206,9 +214,11 @@ class ConvertToLLVMPass std::shared_ptr impl; // Choose the pass implementation. if (useDynamic) - impl = std::make_shared(context, filterDialects); + impl = std::make_shared(context, filterDialects, + allowPatternRollback); else - impl = std::make_shared(context, filterDialects); + impl = std::make_shared(context, filterDialects, + allowPatternRollback); if (failed(impl->initialize())) return failure(); this->impl = impl; @@ -228,8 +238,10 @@ class ConvertToLLVMPass //===----------------------------------------------------------------------===// ConvertToLLVMPassInterface::ConvertToLLVMPassInterface( - MLIRContext *context, ArrayRef filterDialects) - : context(context), filterDialects(filterDialects) {} + MLIRContext *context, ArrayRef filterDialects, + bool allowPatternRollback) + : context(context), filterDialects(filterDialects), + allowPatternRollback(allowPatternRollback) {} void ConvertToLLVMPassInterface::getDependentDialects( DialectRegistry ®istry) { diff --git a/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp b/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp index 830905495e759..221f95a8d8f33 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp @@ -458,6 +458,22 @@ struct LinalgDetensorize } }; + /// A listener that forwards notifyBlockErased and notifyOperationErased to + /// the given callbacks. + struct CallbackListener : public RewriterBase::Listener { + CallbackListener(std::function onOperationErased, + std::function onBlockErased) + : onOperationErased(onOperationErased), onBlockErased(onBlockErased) {} + + void notifyBlockErased(Block *block) override { onBlockErased(block); } + void notifyOperationErased(Operation *op) override { + onOperationErased(op); + } + + std::function onOperationErased; + std::function onBlockErased; + }; + void runOnOperation() override { MLIRContext *context = &getContext(); DetensorizeTypeConverter typeConverter; @@ -551,8 +567,22 @@ struct LinalgDetensorize populateBranchOpInterfaceTypeConversionPattern(patterns, typeConverter, shouldConvertBranchOperand); - if (failed( - applyFullConversion(getOperation(), target, std::move(patterns)))) + ConversionConfig config; + auto onOperationErased = [&](Operation *op) { + opsToDetensor.erase(op); + detensorableBranchOps.erase(op); + }; + auto onBlockErased = [&](Block *block) { + for (BlockArgument arg : block->getArguments()) { + blockArgsToDetensor.erase(arg); + } + }; + CallbackListener listener(onOperationErased, onBlockErased); + + config.listener = &listener; + config.allowPatternRollback = false; + if (failed(applyFullConversion(getOperation(), target, std::move(patterns), + config))) signalPassFailure(); RewritePatternSet canonPatterns(context); diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp index 134aef3a6c719..0e88d31dae8e8 100644 --- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp @@ -730,9 +730,9 @@ class SparseTensorCompressConverter : public OpConversionPattern { {tensor, lvlCoords, values, filled, added, count}, EmitCInterface::On); Operation *parent = getTop(op); + rewriter.setInsertionPointAfter(parent); rewriter.replaceOp(op, adaptor.getTensor()); // Deallocate the buffers on exit of the loop nest. - rewriter.setInsertionPointAfter(parent); memref::DeallocOp::create(rewriter, loc, values); memref::DeallocOp::create(rewriter, loc, filled); memref::DeallocOp::create(rewriter, loc, added); diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp index 0c26b4ed46b31..2ae4718bdc867 100644 --- a/mlir/lib/Transforms/Utils/DialectConversion.cpp +++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp @@ -182,15 +182,24 @@ struct ConversionValueMapping { /// conversions.) static const StringRef kPureTypeConversionMarker = "__pure_type_conversion__"; +/// Return the operation that defines all values in the vector. Return nullptr +/// if the values are not defined by the same operation. +static Operation *getCommonDefiningOp(const ValueVector &values) { + assert(!values.empty() && "expected non-empty value vector"); + Operation *op = values.front().getDefiningOp(); + for (Value v : llvm::drop_begin(values)) { + if (v.getDefiningOp() != op) + return nullptr; + } + return op; +} + /// A vector of values is a pure type conversion if all values are defined by /// the same operation and the operation has the `kPureTypeConversionMarker` /// attribute. static bool isPureTypeConversion(const ValueVector &values) { assert(!values.empty() && "expected non-empty value vector"); - Operation *op = values.front().getDefiningOp(); - for (Value v : llvm::drop_begin(values)) - if (v.getDefiningOp() != op) - return false; + Operation *op = getCommonDefiningOp(values); return op && op->hasAttr(kPureTypeConversionMarker); } @@ -841,7 +850,7 @@ namespace detail { struct ConversionPatternRewriterImpl : public RewriterBase::Listener { explicit ConversionPatternRewriterImpl(MLIRContext *ctx, const ConversionConfig &config) - : context(ctx), config(config) {} + : context(ctx), config(config), notifyingRewriter(ctx, config.listener) {} //===--------------------------------------------------------------------===// // State Management @@ -863,6 +872,7 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener { /// failure. template void appendRewrite(Args &&...args) { + assert(config.allowPatternRollback && "appending rewrites is not allowed"); rewrites.push_back( std::make_unique(*this, std::forward(args)...)); } @@ -889,15 +899,8 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener { bool wasOpReplaced(Operation *op) const; /// Lookup the most recently mapped values with the desired types in the - /// mapping. - /// - /// Special cases: - /// - If the desired type range is empty, simply return the most recently - /// mapped values. - /// - If there is no mapping to the desired types, also return the most - /// recently mapped values. - /// - If there is no mapping for the given values at all, return the given - /// value. + /// mapping, taking into account only replacements. Perform a best-effort + /// search for existing materializations with the desired types. /// /// If `skipPureTypeConversions` is "true", materializations that are pure /// type conversions are not considered. @@ -1066,6 +1069,9 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener { ConversionValueMapping mapping; /// Ordered list of block operations (creations, splits, motions). + /// This vector is maintained only if `allowPatternRollback` is set to + /// "true". Otherwise, all IR rewrites are materialized immediately and no + /// bookkeeping is needed. SmallVector> rewrites; /// A set of operations that should no longer be considered for legalization. @@ -1089,6 +1095,10 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener { /// by the current pattern. SetVector patternInsertedBlocks; + /// A list of unresolved materializations that were created by the current + /// pattern. + DenseSet patternMaterializations; + /// A mapping for looking up metadata of unresolved materializations. DenseMap unresolvedMaterializations; @@ -1104,6 +1114,23 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener { /// Dialect conversion configuration. const ConversionConfig &config; + /// A set of erased operations. This set is utilized only if + /// `allowPatternRollback` is set to "false". Conceptually, this set is + /// simialar to `replacedOps` (which is maintained when the flag is set to + /// "true"). However, erasing from a DenseSet is more efficient than erasing + /// from a SetVector. + DenseSet erasedOps; + + /// A set of erased blocks. This set is utilized only if + /// `allowPatternRollback` is set to "false". + DenseSet erasedBlocks; + + /// A rewriter that notifies the listener (if any) about all IR + /// modifications. This rewriter is utilized only if `allowPatternRollback` + /// is set to "false". If the flag is set to "true", the listener is notified + /// with a separate mechanism (e.g., in `IRRewrite::commit`). + IRRewriter notifyingRewriter; + #ifndef NDEBUG /// A set of operations that have pending updates. This tracking isn't /// strictly necessary, and is thus only active during debug builds for extra @@ -1140,11 +1167,8 @@ void BlockTypeConversionRewrite::rollback() { getNewBlock()->replaceAllUsesWith(getOrigBlock()); } -void ReplaceBlockArgRewrite::commit(RewriterBase &rewriter) { - Value repl = rewriterImpl.findOrBuildReplacementValue(arg, converter); - if (!repl) - return; - +static void performReplaceBlockArg(RewriterBase &rewriter, BlockArgument arg, + Value repl) { if (isa(repl)) { rewriter.replaceAllUsesWith(arg, repl); return; @@ -1161,6 +1185,13 @@ void ReplaceBlockArgRewrite::commit(RewriterBase &rewriter) { }); } +void ReplaceBlockArgRewrite::commit(RewriterBase &rewriter) { + Value repl = rewriterImpl.findOrBuildReplacementValue(arg, converter); + if (!repl) + return; + performReplaceBlockArg(rewriter, arg, repl); +} + void ReplaceBlockArgRewrite::rollback() { rewriterImpl.mapping.erase({arg}); } void ReplaceOperationRewrite::commit(RewriterBase &rewriter) { @@ -1246,6 +1277,30 @@ void ConversionPatternRewriterImpl::applyRewrites() { ValueVector ConversionPatternRewriterImpl::lookupOrDefault( Value from, TypeRange desiredTypes, bool skipPureTypeConversions) const { + // Helper function that looks up a single value. + auto lookup = [&](const ValueVector &values) -> ValueVector { + assert(!values.empty() && "expected non-empty value vector"); + + // If the pattern rollback is enabled, use the mapping to look up the + // values. + if (config.allowPatternRollback) + return mapping.lookup(values); + + // Otherwise, look up values by examining the IR. All replacements have + // already been materialized in IR. + Operation *op = getCommonDefiningOp(values); + if (!op) + return {}; + auto castOp = dyn_cast(op); + if (!castOp) + return {}; + if (!this->unresolvedMaterializations.contains(castOp)) + return {}; + if (castOp.getOutputs() != values) + return {}; + return castOp.getInputs(); + }; + // Helper function that looks up each value in `values` individually and then // composes the results. If that fails, it tries to look up the entire vector // at once. @@ -1253,7 +1308,7 @@ ValueVector ConversionPatternRewriterImpl::lookupOrDefault( // If possible, replace each value with (one or multiple) mapped values. ValueVector next; for (Value v : values) { - ValueVector r = mapping.lookup({v}); + ValueVector r = lookup({v}); if (!r.empty()) { llvm::append_range(next, r); } else { @@ -1273,7 +1328,7 @@ ValueVector ConversionPatternRewriterImpl::lookupOrDefault( // be stored (and looked up) in the mapping. But for performance reasons, // we choose to reuse existing IR (when possible) instead of creating it // multiple times. - ValueVector r = mapping.lookup(values); + ValueVector r = lookup(values); if (r.empty()) { // No mapping found: The lookup stops here. return {}; @@ -1347,15 +1402,8 @@ void ConversionPatternRewriterImpl::resetState(RewriterState state, void ConversionPatternRewriterImpl::undoRewrites(unsigned numRewritesToKeep, StringRef patternName) { for (auto &rewrite : - llvm::reverse(llvm::drop_begin(rewrites, numRewritesToKeep))) { - if (!config.allowPatternRollback && - !isa(rewrite)) { - // Unresolved materializations can always be rolled back (erased). - llvm::report_fatal_error("pattern '" + patternName + - "' rollback of IR modifications requested"); - } + llvm::reverse(llvm::drop_begin(rewrites, numRewritesToKeep))) rewrite->rollback(); - } rewrites.resize(numRewritesToKeep); } @@ -1419,12 +1467,12 @@ LogicalResult ConversionPatternRewriterImpl::remapValues( bool ConversionPatternRewriterImpl::isOpIgnored(Operation *op) const { // Check to see if this operation is ignored or was replaced. - return replacedOps.count(op) || ignoredOps.count(op); + return wasOpReplaced(op) || ignoredOps.count(op); } bool ConversionPatternRewriterImpl::wasOpReplaced(Operation *op) const { // Check to see if this operation was replaced. - return replacedOps.count(op); + return replacedOps.count(op) || erasedOps.count(op); } //===----------------------------------------------------------------------===// @@ -1508,7 +1556,8 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion( // a bit more efficient, so we try to do that when possible. bool fastPath = !config.listener; if (fastPath) { - appendRewrite(newBlock, block, newBlock->end()); + if (config.allowPatternRollback) + appendRewrite(newBlock, block, newBlock->end()); newBlock->getOperations().splice(newBlock->end(), block->getOperations()); } else { while (!block->empty()) @@ -1556,7 +1605,8 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion( replaceUsesOfBlockArgument(origArg, replArgs, converter); } - appendRewrite(/*origBlock=*/block, newBlock); + if (config.allowPatternRollback) + appendRewrite(/*origBlock=*/block, newBlock); // Erase the old block. (It is just unlinked for now and will be erased during // cleanup.) @@ -1585,23 +1635,32 @@ ValueRange ConversionPatternRewriterImpl::buildUnresolvedMaterialization( // tracking the materialization like we do for other operations. OpBuilder builder(outputTypes.front().getContext()); builder.setInsertionPoint(ip.getBlock(), ip.getPoint()); - auto convertOp = + UnrealizedConversionCastOp convertOp = UnrealizedConversionCastOp::create(builder, loc, outputTypes, inputs); if (isPureTypeConversion) convertOp->setAttr(kPureTypeConversionMarker, builder.getUnitAttr()); - if (!valuesToMap.empty()) - mapping.map(valuesToMap, convertOp.getResults()); + + // Register the materialization. if (castOp) *castOp = convertOp; unresolvedMaterializations[convertOp] = UnresolvedMaterializationInfo(converter, kind, originalType); - appendRewrite(convertOp, - std::move(valuesToMap)); + if (config.allowPatternRollback) { + if (!valuesToMap.empty()) + mapping.map(valuesToMap, convertOp.getResults()); + appendRewrite(convertOp, + std::move(valuesToMap)); + } else { + patternMaterializations.insert(convertOp); + } return convertOp.getResults(); } Value ConversionPatternRewriterImpl::findOrBuildReplacementValue( Value value, const TypeConverter *converter) { + assert(config.allowPatternRollback && + "this code path is valid only in rollback mode"); + // Try to find a replacement value with the same type in the conversion value // mapping. This includes cached materializations. We try to reuse those // instead of generating duplicate IR. @@ -1663,26 +1722,119 @@ void ConversionPatternRewriterImpl::notifyOperationInserted( logger.getOStream() << " (was detached)"; logger.getOStream() << "\n"; }); - assert(!wasOpReplaced(op->getParentOp()) && + + // In rollback mode, it is easier to misuse the API, so perform extra error + // checking. + assert(!(config.allowPatternRollback && wasOpReplaced(op->getParentOp())) && "attempting to insert into a block within a replaced/erased op"); + // In "no rollback" mode, the listener is always notified immediately. + if (!config.allowPatternRollback && config.listener) + config.listener->notifyOperationInserted(op, previous); + if (wasDetached) { - // If the op was detached, it is most likely a newly created op. - // TODO: If the same op is inserted multiple times from a detached state, - // the rollback mechanism may erase the same op multiple times. This is a - // bug in the rollback-based dialect conversion driver. - appendRewrite(op); + // If the op was detached, it is most likely a newly created op. Add it the + // set of newly created ops, so that it will be legalized. If this op is + // not a newly created op, it will be legalized a second time, which is + // inefficient but harmless. patternNewOps.insert(op); + + if (config.allowPatternRollback) { + // TODO: If the same op is inserted multiple times from a detached + // state, the rollback mechanism may erase the same op multiple times. + // This is a bug in the rollback-based dialect conversion driver. + appendRewrite(op); + } else { + // In "no rollback" mode, there is an extra data structure for tracking + // erased operations that must be kept up to date. + erasedOps.erase(op); + } return; } // The op was moved from one place to another. - appendRewrite(op, previous); + if (config.allowPatternRollback) + appendRewrite(op, previous); +} + +/// Given that `fromRange` is about to be replaced with `toRange`, compute +/// replacement values with the types of `fromRange`. +static SmallVector +getReplacementValues(ConversionPatternRewriterImpl &impl, ValueRange fromRange, + const SmallVector> &toRange, + const TypeConverter *converter) { + assert(!impl.config.allowPatternRollback && + "this code path is valid only in 'no rollback' mode"); + SmallVector repls; + for (auto [from, to] : llvm::zip_equal(fromRange, toRange)) { + if (from.use_empty()) { + // The replaced value is dead. No replacement value is needed. + repls.push_back(Value()); + continue; + } + + if (to.empty()) { + // The replaced value is dropped. Materialize a replacement value "out of + // thin air". + Value srcMat = impl.buildUnresolvedMaterialization( + MaterializationKind::Source, computeInsertPoint(from), from.getLoc(), + /*valuesToMap=*/{}, /*inputs=*/ValueRange(), + /*outputTypes=*/from.getType(), /*originalType=*/Type(), + converter)[0]; + repls.push_back(srcMat); + continue; + } + + if (TypeRange(to) == TypeRange(from.getType())) { + // The replacement value already has the correct type. Use it directly. + repls.push_back(to[0]); + continue; + } + + // The replacement value has the wrong type. Build a source materialization + // to the original type. + // TODO: This is a bit inefficient. We should try to reuse existing + // materializations if possible. This would require an extension of the + // `lookupOrDefault` API. + Value srcMat = impl.buildUnresolvedMaterialization( + MaterializationKind::Source, computeInsertPoint(to), from.getLoc(), + /*valuesToMap=*/{}, /*inputs=*/to, /*outputTypes=*/from.getType(), + /*originalType=*/Type(), converter)[0]; + repls.push_back(srcMat); + } + + return repls; } void ConversionPatternRewriterImpl::replaceOp( Operation *op, SmallVector> &&newValues) { - assert(newValues.size() == op->getNumResults()); + assert(newValues.size() == op->getNumResults() && + "incorrect number of replacement values"); + + if (!config.allowPatternRollback) { + // Pattern rollback is not allowed: materialize all IR changes immediately. + SmallVector repls = getReplacementValues( + *this, op->getResults(), newValues, currentTypeConverter); + // Update internal data structures, so that there are no dangling pointers + // to erased IR. + op->walk([&](Operation *op) { + erasedOps.insert(op); + ignoredOps.remove(op); + if (auto castOp = dyn_cast(op)) { + unresolvedMaterializations.erase(castOp); + patternMaterializations.erase(castOp); + } + // The original op will be erased, so remove it from the set of + // unlegalized ops. + if (config.unlegalizedOps) + config.unlegalizedOps->erase(op); + }); + op->walk([&](Block *block) { erasedBlocks.insert(block); }); + // Replace the op with the replacement values and notify the listener. + notifyingRewriter.replaceOp(op, repls); + return; + } + assert(!ignoredOps.contains(op) && "operation was already replaced"); // Check if replaced op is an unresolved materialization, i.e., an @@ -1722,11 +1874,46 @@ void ConversionPatternRewriterImpl::replaceOp( void ConversionPatternRewriterImpl::replaceUsesOfBlockArgument( BlockArgument from, ValueRange to, const TypeConverter *converter) { + if (!config.allowPatternRollback) { + SmallVector toConv = llvm::to_vector(to); + SmallVector repls = + getReplacementValues(*this, from, {toConv}, converter); + IRRewriter r(from.getContext()); + Value repl = repls.front(); + if (!repl) + return; + + performReplaceBlockArg(r, from, repl); + return; + } + appendRewrite(from.getOwner(), from, converter); mapping.map(from, to); } void ConversionPatternRewriterImpl::eraseBlock(Block *block) { + if (!config.allowPatternRollback) { + // Pattern rollback is not allowed: materialize all IR changes immediately. + // Update internal data structures, so that there are no dangling pointers + // to erased IR. + block->walk([&](Operation *op) { + erasedOps.insert(op); + ignoredOps.remove(op); + if (auto castOp = dyn_cast(op)) { + unresolvedMaterializations.erase(castOp); + patternMaterializations.erase(castOp); + } + // The original op will be erased, so remove it from the set of + // unlegalized ops. + if (config.unlegalizedOps) + config.unlegalizedOps->erase(op); + }); + block->walk([&](Block *block) { erasedBlocks.insert(block); }); + // Erase the block and notify the listener. + notifyingRewriter.eraseBlock(block); + return; + } + assert(!wasOpReplaced(block->getParentOp()) && "attempting to erase a block within a replaced/erased op"); appendRewrite(block); @@ -1760,23 +1947,37 @@ void ConversionPatternRewriterImpl::notifyBlockInserted( logger.getOStream() << " (was detached)"; logger.getOStream() << "\n"; }); - assert(!wasOpReplaced(newParentOp) && + + // In rollback mode, it is easier to misuse the API, so perform extra error + // checking. + assert(!(config.allowPatternRollback && wasOpReplaced(newParentOp)) && "attempting to insert into a region within a replaced/erased op"); (void)newParentOp; + // In "no rollback" mode, the listener is always notified immediately. + if (!config.allowPatternRollback && config.listener) + config.listener->notifyBlockInserted(block, previous, previousIt); + patternInsertedBlocks.insert(block); if (wasDetached) { // If the block was detached, it is most likely a newly created block. - // TODO: If the same block is inserted multiple times from a detached state, - // the rollback mechanism may erase the same block multiple times. This is a - // bug in the rollback-based dialect conversion driver. - appendRewrite(block); + if (config.allowPatternRollback) { + // TODO: If the same block is inserted multiple times from a detached + // state, the rollback mechanism may erase the same block multiple times. + // This is a bug in the rollback-based dialect conversion driver. + appendRewrite(block); + } else { + // In "no rollback" mode, there is an extra data structure for tracking + // erased blocks that must be kept up to date. + erasedBlocks.erase(block); + } return; } // The block was moved from one place to another. - appendRewrite(block, previous, previousIt); + if (config.allowPatternRollback) + appendRewrite(block, previous, previousIt); } void ConversionPatternRewriterImpl::inlineBlockBefore(Block *source, @@ -1956,7 +2157,7 @@ void ConversionPatternRewriter::inlineBlockBefore(Block *source, Block *dest, // a bit more efficient, so we try to do that when possible. bool fastPath = !getConfig().listener; - if (fastPath) + if (fastPath && impl->config.allowPatternRollback) impl->inlineBlockBefore(source, dest, before); // Replace all uses of block arguments. @@ -1982,6 +2183,11 @@ void ConversionPatternRewriter::inlineBlockBefore(Block *source, Block *dest, } void ConversionPatternRewriter::startOpModification(Operation *op) { + if (!impl->config.allowPatternRollback) { + // Pattern rollback is not allowed: no extra bookkeeping is needed. + PatternRewriter::startOpModification(op); + return; + } assert(!impl->wasOpReplaced(op) && "attempting to modify a replaced/erased op"); #ifndef NDEBUG @@ -1991,20 +2197,29 @@ void ConversionPatternRewriter::startOpModification(Operation *op) { } void ConversionPatternRewriter::finalizeOpModification(Operation *op) { - assert(!impl->wasOpReplaced(op) && - "attempting to modify a replaced/erased op"); - PatternRewriter::finalizeOpModification(op); impl->patternModifiedOps.insert(op); + if (!impl->config.allowPatternRollback) { + PatternRewriter::finalizeOpModification(op); + if (getConfig().listener) + getConfig().listener->notifyOperationModified(op); + return; + } // There is nothing to do here, we only need to track the operation at the // start of the update. #ifndef NDEBUG + assert(!impl->wasOpReplaced(op) && + "attempting to modify a replaced/erased op"); assert(impl->pendingRootUpdates.erase(op) && "operation did not have a pending in-place update"); #endif } void ConversionPatternRewriter::cancelOpModification(Operation *op) { + if (!impl->config.allowPatternRollback) { + PatternRewriter::cancelOpModification(op); + return; + } #ifndef NDEBUG assert(impl->pendingRootUpdates.erase(op) && "operation did not have a pending in-place update"); @@ -2425,17 +2640,35 @@ OperationLegalizer::legalizeWithPattern(Operation *op, RewriterState curState = rewriterImpl.getCurrentState(); auto onFailure = [&](const Pattern &pattern) { assert(rewriterImpl.pendingRootUpdates.empty() && "dangling root updates"); -#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS if (!rewriterImpl.config.allowPatternRollback) { - // Returning "failure" after modifying IR is not allowed. + // Erase all unresolved materializations. + for (auto op : rewriterImpl.patternMaterializations) { + rewriterImpl.unresolvedMaterializations.erase(op); + op.erase(); + } + rewriterImpl.patternMaterializations.clear(); +#if 0 + // Cheap pattern check that could have false positives. Can be enabled + // manually for debugging purposes. E.g., this check would report an API + // violation when an op is created and then erased in the same pattern. + if (!rewriterImpl.patternNewOps.empty() || + !rewriterImpl.patternModifiedOps.empty() || + !rewriterImpl.patternInsertedBlocks.empty()) { + llvm::report_fatal_error("pattern '" + pattern.getDebugName() + + "' rollback of IR modifications requested"); + } +#endif +#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS + // Expensive pattern check that can detect more API violations and has no + // fewer false positives than the cheap check. if (checkOp) { OperationFingerPrint fingerPrintAfterPattern(checkOp); if (fingerPrintAfterPattern != *topLevelFingerPrint) llvm::report_fatal_error("pattern '" + pattern.getDebugName() + "' returned failure but IR did change"); } - } #endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS + } rewriterImpl.patternNewOps.clear(); rewriterImpl.patternModifiedOps.clear(); rewriterImpl.patternInsertedBlocks.clear(); @@ -2459,6 +2692,16 @@ OperationLegalizer::legalizeWithPattern(Operation *op, // successfully applied. auto onSuccess = [&](const Pattern &pattern) { assert(rewriterImpl.pendingRootUpdates.empty() && "dangling root updates"); + if (!rewriterImpl.config.allowPatternRollback) { + // Eagerly erase unused materializations. + for (auto op : rewriterImpl.patternMaterializations) { + if (op->use_empty()) { + rewriterImpl.unresolvedMaterializations.erase(op); + op.erase(); + } + } + rewriterImpl.patternMaterializations.clear(); + } SetVector newOps = moveAndReset(rewriterImpl.patternNewOps); SetVector modifiedOps = moveAndReset(rewriterImpl.patternModifiedOps); @@ -2549,6 +2792,9 @@ LogicalResult OperationLegalizer::legalizePatternBlockRewrites( // If the pattern moved or created any blocks, make sure the types of block // arguments get legalized. for (Block *block : insertedBlocks) { + if (impl.erasedBlocks.contains(block)) + continue; + // Only check blocks outside of the current operation. Operation *parentOp = block->getParentOp(); if (!parentOp || parentOp == op || block->getNumArguments() == 0) diff --git a/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir b/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir index 83bdbe1f67118..ba12ff29ebef9 100644 --- a/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir +++ b/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir @@ -3,6 +3,7 @@ // Same below, but using the `ConvertToLLVMPatternInterface` entry point // and the generic `convert-to-llvm` pass. // RUN: mlir-opt --convert-to-llvm="filter-dialects=arith" --split-input-file %s | FileCheck %s +// RUN: mlir-opt --convert-to-llvm="filter-dialects=arith allow-pattern-rollback=0" --split-input-file %s | FileCheck %s // CHECK-LABEL: @vector_ops func.func @vector_ops(%arg0: vector<4xf32>, %arg1: vector<4xi1>, %arg2: vector<4xi64>, %arg3: vector<4xi64>) -> vector<4xf32> { @@ -373,12 +374,11 @@ func.func @integer_extension_and_truncation(%arg0 : i3) { // CHECK-LABEL: @integer_cast_0d_vector func.func @integer_cast_0d_vector(%arg0 : vector) { -// CHECK: %[[ARG0:.*]] = builtin.unrealized_conversion_cast -// CHECK-NEXT: = llvm.sext %[[ARG0]] : vector<1xi3> to vector<1xi6> +// CHECK: = llvm.sext %{{.*}}: vector<1xi3> to vector<1xi6> %0 = arith.extsi %arg0 : vector to vector -// CHECK-NEXT: = llvm.zext %[[ARG0]] : vector<1xi3> to vector<1xi6> +// CHECK-NEXT: = llvm.zext %{{.*}} : vector<1xi3> to vector<1xi6> %1 = arith.extui %arg0 : vector to vector -// CHECK-NEXT: = llvm.trunc %[[ARG0]] : vector<1xi3> to vector<1xi2> +// CHECK-NEXT: = llvm.trunc %{{.*}} : vector<1xi3> to vector<1xi2> %2 = arith.trunci %arg0 : vector to vector return } diff --git a/mlir/test/Conversion/ComplexToLLVM/convert-to-llvm.mlir b/mlir/test/Conversion/ComplexToLLVM/convert-to-llvm.mlir index ad1b6658fbe78..4d2c12a56eaca 100644 --- a/mlir/test/Conversion/ComplexToLLVM/convert-to-llvm.mlir +++ b/mlir/test/Conversion/ComplexToLLVM/convert-to-llvm.mlir @@ -3,6 +3,7 @@ // Same below, but using the `ConvertToLLVMPatternInterface` entry point // and the generic `convert-to-llvm` pass. // RUN: mlir-opt --convert-to-llvm="filter-dialects=complex" --split-input-file %s | FileCheck %s +// RUN: mlir-opt --convert-to-llvm="filter-dialects=complex allow-pattern-rollback=0" --split-input-file %s | FileCheck %s // CHECK-LABEL: func @complex_create // CHECK-SAME: (%[[REAL0:.*]]: f32, %[[IMAG0:.*]]: f32) @@ -23,9 +24,9 @@ func.func @complex_constant() -> complex { // CHECK-LABEL: func @complex_extract // CHECK-SAME: (%[[CPLX:.*]]: complex) -// CHECK-NEXT: %[[CAST0:.*]] = builtin.unrealized_conversion_cast %[[CPLX]] : complex to !llvm.struct<(f32, f32)> -// CHECK-NEXT: %[[REAL:.*]] = llvm.extractvalue %[[CAST0]][0] : !llvm.struct<(f32, f32)> -// CHECK-NEXT: %[[IMAG:.*]] = llvm.extractvalue %[[CAST0]][1] : !llvm.struct<(f32, f32)> +// CHECK: builtin.unrealized_conversion_cast %[[CPLX]] : complex to !llvm.struct<(f32, f32)> +// CHECK: %[[REAL:.*]] = llvm.extractvalue %{{.*}}[0] : !llvm.struct<(f32, f32)> +// CHECK: %[[IMAG:.*]] = llvm.extractvalue %{{.*}}[1] : !llvm.struct<(f32, f32)> func.func @complex_extract(%cplx: complex) { %real1 = complex.re %cplx : complex %imag1 = complex.im %cplx : complex diff --git a/mlir/test/Conversion/ControlFlowToLLVM/assert.mlir b/mlir/test/Conversion/ControlFlowToLLVM/assert.mlir index 3ec8f1fa1e567..18d0526ecf1a9 100644 --- a/mlir/test/Conversion/ControlFlowToLLVM/assert.mlir +++ b/mlir/test/Conversion/ControlFlowToLLVM/assert.mlir @@ -3,6 +3,7 @@ // Same below, but using the `ConvertToLLVMPatternInterface` entry point // and the generic `convert-to-llvm` pass. // RUN: mlir-opt --convert-to-llvm="filter-dialects=cf" --split-input-file %s | FileCheck %s +// RUN: mlir-opt --convert-to-llvm="filter-dialects=cf allow-pattern-rollback=0" --split-input-file %s | FileCheck %s func.func @main() { %a = arith.constant 0 : i1 diff --git a/mlir/test/Conversion/FuncToLLVM/func-to-llvm.mlir b/mlir/test/Conversion/FuncToLLVM/func-to-llvm.mlir index 2113557fbbb15..94dfceadbc449 100644 --- a/mlir/test/Conversion/FuncToLLVM/func-to-llvm.mlir +++ b/mlir/test/Conversion/FuncToLLVM/func-to-llvm.mlir @@ -9,6 +9,7 @@ // Same below, but using the `ConvertToLLVMPatternInterface` entry point // and the generic `convert-to-llvm` pass. // RUN: mlir-opt --convert-to-llvm="filter-dialects=arith,cf,func,math" %s | FileCheck %s +// RUN: mlir-opt --convert-to-llvm="filter-dialects=arith,cf,func,math allow-pattern-rollback=0" %s | FileCheck %s // CHECK-LABEL: func @empty() { // CHECK-NEXT: llvm.return diff --git a/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm-target-attr.mlir b/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm-target-attr.mlir index ed7fa6508d5ad..0016db5084584 100644 --- a/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm-target-attr.mlir +++ b/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm-target-attr.mlir @@ -1,4 +1,5 @@ // RUN: mlir-opt %s --pass-pipeline="builtin.module(gpu.module(convert-to-llvm{dynamic=true}))" | FileCheck %s +// RUN: mlir-opt %s --pass-pipeline="builtin.module(gpu.module(convert-to-llvm{dynamic=true allow-pattern-rollback=0}))" | FileCheck %s // CHECK-LABEL: gpu.module @nvvm_module gpu.module @nvvm_module [#nvvm.target] { diff --git a/mlir/test/Conversion/IndexToLLVM/index-to-llvm.mlir b/mlir/test/Conversion/IndexToLLVM/index-to-llvm.mlir index 26abb3bdc23a1..007929ed677fa 100644 --- a/mlir/test/Conversion/IndexToLLVM/index-to-llvm.mlir +++ b/mlir/test/Conversion/IndexToLLVM/index-to-llvm.mlir @@ -5,6 +5,7 @@ // Same below, but using the `ConvertToLLVMPatternInterface` entry point // and the generic `convert-to-llvm` pass. // RUN: mlir-opt --convert-to-llvm="filter-dialects=index" --split-input-file %s | FileCheck %s +// RUN: mlir-opt --convert-to-llvm="filter-dialects=index allow-pattern-rollback=0" --split-input-file %s | FileCheck %s // CHECK-LABEL: @trivial_ops func.func @trivial_ops(%a: index, %b: index) { diff --git a/mlir/test/Conversion/MathToLLVM/math-to-llvm.mlir b/mlir/test/Conversion/MathToLLVM/math-to-llvm.mlir index 92904082a6f46..f4541220fe4d2 100644 --- a/mlir/test/Conversion/MathToLLVM/math-to-llvm.mlir +++ b/mlir/test/Conversion/MathToLLVM/math-to-llvm.mlir @@ -3,6 +3,7 @@ // Same below, but using the `ConvertToLLVMPatternInterface` entry point // and the generic `convert-to-llvm` pass. // RUN: mlir-opt --convert-to-llvm="filter-dialects=math" --split-input-file %s | FileCheck %s +// RUN: mlir-opt --convert-to-llvm="filter-dialects=math allow-pattern-rollback=0" --split-input-file %s | FileCheck %s // CHECK-LABEL: @ops func.func @ops(%arg0: f32, %arg1: f32, %arg2: i32, %arg3: i32, %arg4: f64) { diff --git a/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir b/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir index e50576722e38c..24873340d7122 100644 --- a/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir +++ b/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir @@ -3,6 +3,7 @@ // Same below, but using the `ConvertToLLVMPatternInterface` entry point // and the generic `convert-to-llvm` pass. // RUN: mlir-opt --convert-to-llvm --split-input-file %s | FileCheck %s +// RUN: mlir-opt --convert-to-llvm="allow-pattern-rollback=0" --split-input-file %s | FileCheck %s // CHECK-LABEL: @init_mbarrier llvm.func @init_mbarrier(%barrier_gen : !llvm.ptr, %barrier : !llvm.ptr<3>, %count : i32, %pred : i1) { diff --git a/mlir/test/Conversion/OpenMPToLLVM/convert-to-llvmir.mlir b/mlir/test/Conversion/OpenMPToLLVM/convert-to-llvmir.mlir index d69de998346b5..7d8ccd910cdf4 100644 --- a/mlir/test/Conversion/OpenMPToLLVM/convert-to-llvmir.mlir +++ b/mlir/test/Conversion/OpenMPToLLVM/convert-to-llvmir.mlir @@ -1,5 +1,6 @@ // RUN: mlir-opt -convert-openmp-to-llvm -split-input-file %s | FileCheck %s // RUN: mlir-opt -convert-to-llvm -split-input-file %s | FileCheck %s +// RUN: mlir-opt -convert-to-llvm="allow-pattern-rollback=0" -split-input-file %s | FileCheck %s // CHECK-LABEL: llvm.func @foo(i64, i64) func.func private @foo(index, index) diff --git a/mlir/test/Conversion/UBToLLVM/ub-to-llvm.mlir b/mlir/test/Conversion/UBToLLVM/ub-to-llvm.mlir index 5307e477b8786..6c0b111d4c2c5 100644 --- a/mlir/test/Conversion/UBToLLVM/ub-to-llvm.mlir +++ b/mlir/test/Conversion/UBToLLVM/ub-to-llvm.mlir @@ -3,6 +3,7 @@ // Same below, but using the `ConvertToLLVMPatternInterface` entry point // and the generic `convert-to-llvm` pass. // RUN: mlir-opt --convert-to-llvm="filter-dialects=ub" --split-input-file %s | FileCheck %s +// RUN: mlir-opt --convert-to-llvm="filter-dialects=ub allow-pattern-rollback=0" --split-input-file %s | FileCheck %s // CHECK-LABEL: @check_poison func.func @check_poison() { diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm-interface.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm-interface.mlir index 5a424a8ac0d5f..9b57b1b6fb4c7 100644 --- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm-interface.mlir +++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm-interface.mlir @@ -1,4 +1,5 @@ // RUN: mlir-opt --convert-to-llvm="filter-dialects=vector" --split-input-file %s | FileCheck %s +// RUN: mlir-opt --convert-to-llvm="filter-dialects=vector allow-pattern-rollback=0" --split-input-file %s | FileCheck %s // RUN: mlir-opt %s -convert-vector-to-llvm -split-input-file | FileCheck %s //===========================================================================// @@ -182,8 +183,7 @@ func.func @shuffle_0D_direct(%arg0: vector) -> vector<3xf32> { } // CHECK-LABEL: @shuffle_0D_direct( // CHECK-SAME: %[[A:.*]]: vector -// CHECK: %[[c:.*]] = builtin.unrealized_conversion_cast %[[A]] : vector to vector<1xf32> -// CHECK: %[[s:.*]] = llvm.shufflevector %[[c]], %[[c]] [0, 1, 0] : vector<1xf32> +// CHECK: %[[s:.*]] = llvm.shufflevector %{{.*}}, %{{.*}} [0, 1, 0] : vector<1xf32> // CHECK: return %[[s]] : vector<3xf32> // ----- diff --git a/mlir/test/Dialect/Linalg/detensorize_0d.mlir b/mlir/test/Dialect/Linalg/detensorize_0d.mlir index 74931cb0830bc..5c29b04630cad 100644 --- a/mlir/test/Dialect/Linalg/detensorize_0d.mlir +++ b/mlir/test/Dialect/Linalg/detensorize_0d.mlir @@ -53,10 +53,11 @@ func.func @detensor_op_sequence(%arg1: tensor, %arg2: tensor) -> tenso } // CHECK-LABEL: func @detensor_op_sequence // CHECK-SAME: (%[[arg1:.*]]: tensor, %[[arg2:.*]]: tensor) -// CHECK-DAG: %[[arg1_val:.*]] = tensor.extract %[[arg1]] +// CHECK-DAG: %[[arg1_val_1:.*]] = tensor.extract %[[arg1]] // CHECK-DAG: %[[arg2_val:.*]] = tensor.extract %[[arg2]] -// CHECK: %[[detensored_res:.*]] = arith.addf %[[arg1_val]], %[[arg2_val]] -// CHECK: %[[detensored_res2:.*]] = arith.mulf %[[arg1_val]], %[[detensored_res]] +// CHECK-DAG: %[[arg1_val_2:.*]] = tensor.extract %[[arg1]] +// CHECK: %[[detensored_res:.*]] = arith.addf %[[arg1_val_2]], %[[arg2_val]] +// CHECK: %[[detensored_res2:.*]] = arith.mulf %[[arg1_val_1]], %[[detensored_res]] // CHECK: %[[detensored_res3:.*]] = arith.divf %[[detensored_res]], %[[detensored_res2]] // CHECK: %[[new_tensor_res:.*]] = tensor.from_elements %[[detensored_res3]] // CHECK: return %[[new_tensor_res]] diff --git a/mlir/test/Integration/Dialect/MemRef/assume-alignment-runtime-verification.mlir b/mlir/test/Integration/Dialect/MemRef/assume-alignment-runtime-verification.mlir index 8f74976c59773..25a338df8d790 100644 --- a/mlir/test/Integration/Dialect/MemRef/assume-alignment-runtime-verification.mlir +++ b/mlir/test/Integration/Dialect/MemRef/assume-alignment-runtime-verification.mlir @@ -6,6 +6,15 @@ // RUN: -shared-libs=%mlir_runner_utils 2>&1 | \ // RUN: FileCheck %s +// RUN: mlir-opt %s -generate-runtime-verification \ +// RUN: -expand-strided-metadata \ +// RUN: -test-cf-assert \ +// RUN: -convert-to-llvm="allow-pattern-rollback=0" \ +// RUN: -reconcile-unrealized-casts | \ +// RUN: mlir-runner -e main -entry-point-result=void \ +// RUN: -shared-libs=%mlir_runner_utils 2>&1 | \ +// RUN: FileCheck %s + func.func @main() { // This buffer is properly aligned. There should be no error. // CHECK-NOT: ^ memref is not aligned to 8 diff --git a/mlir/test/Integration/Dialect/MemRef/atomic-rmw-runtime-verification.mlir b/mlir/test/Integration/Dialect/MemRef/atomic-rmw-runtime-verification.mlir index 26c731c921356..4c6a48d577a6c 100644 --- a/mlir/test/Integration/Dialect/MemRef/atomic-rmw-runtime-verification.mlir +++ b/mlir/test/Integration/Dialect/MemRef/atomic-rmw-runtime-verification.mlir @@ -5,6 +5,14 @@ // RUN: -shared-libs=%mlir_runner_utils 2>&1 | \ // RUN: FileCheck %s +// RUN: mlir-opt %s -generate-runtime-verification \ +// RUN: -test-cf-assert \ +// RUN: -convert-to-llvm="allow-pattern-rollback=0" \ +// RUN: -reconcile-unrealized-casts | \ +// RUN: mlir-runner -e main -entry-point-result=void \ +// RUN: -shared-libs=%mlir_runner_utils 2>&1 | \ +// RUN: FileCheck %s + func.func @store_dynamic(%memref: memref, %index: index) { %cst = arith.constant 1.0 : f32 memref.atomic_rmw addf %cst, %memref[%index] : (f32, memref) -> f32 diff --git a/mlir/test/Integration/Dialect/MemRef/cast-runtime-verification.mlir b/mlir/test/Integration/Dialect/MemRef/cast-runtime-verification.mlir index 8b6308e9c1939..1ac10306395ad 100644 --- a/mlir/test/Integration/Dialect/MemRef/cast-runtime-verification.mlir +++ b/mlir/test/Integration/Dialect/MemRef/cast-runtime-verification.mlir @@ -1,11 +1,20 @@ // RUN: mlir-opt %s -generate-runtime-verification \ -// RUN: -test-cf-assert \ // RUN: -expand-strided-metadata \ +// RUN: -test-cf-assert \ // RUN: -convert-to-llvm | \ // RUN: mlir-runner -e main -entry-point-result=void \ // RUN: -shared-libs=%mlir_runner_utils 2>&1 | \ // RUN: FileCheck %s +// RUN: mlir-opt %s -generate-runtime-verification \ +// RUN: -expand-strided-metadata \ +// RUN: -test-cf-assert \ +// RUN: -convert-to-llvm="allow-pattern-rollback=0" \ +// RUN: -reconcile-unrealized-casts | \ +// RUN: mlir-runner -e main -entry-point-result=void \ +// RUN: -shared-libs=%mlir_runner_utils 2>&1 | \ +// RUN: FileCheck %s + func.func @cast_to_static_dim(%m: memref) -> memref<10xf32> { %0 = memref.cast %m : memref to memref<10xf32> return %0 : memref<10xf32> diff --git a/mlir/test/Integration/Dialect/MemRef/copy-runtime-verification.mlir b/mlir/test/Integration/Dialect/MemRef/copy-runtime-verification.mlir index 95b9db2832cee..be9417baf93df 100644 --- a/mlir/test/Integration/Dialect/MemRef/copy-runtime-verification.mlir +++ b/mlir/test/Integration/Dialect/MemRef/copy-runtime-verification.mlir @@ -6,6 +6,15 @@ // RUN: -shared-libs=%mlir_runner_utils 2>&1 | \ // RUN: FileCheck %s +// RUN: mlir-opt %s -generate-runtime-verification \ +// RUN: -expand-strided-metadata \ +// RUN: -test-cf-assert \ +// RUN: -convert-to-llvm="allow-pattern-rollback=0" \ +// RUN: -reconcile-unrealized-casts | \ +// RUN: mlir-runner -e main -entry-point-result=void \ +// RUN: -shared-libs=%mlir_runner_utils 2>&1 | \ +// RUN: FileCheck %s + // Put memref.copy in a function, otherwise the memref.cast may fold. func.func @memcpy_helper(%src: memref, %dest: memref) { memref.copy %src, %dest : memref to memref diff --git a/mlir/test/Integration/Dialect/MemRef/dim-runtime-verification.mlir b/mlir/test/Integration/Dialect/MemRef/dim-runtime-verification.mlir index 2e3f271743c93..ef4af62459738 100644 --- a/mlir/test/Integration/Dialect/MemRef/dim-runtime-verification.mlir +++ b/mlir/test/Integration/Dialect/MemRef/dim-runtime-verification.mlir @@ -6,6 +6,15 @@ // RUN: -shared-libs=%mlir_runner_utils 2>&1 | \ // RUN: FileCheck %s +// RUN: mlir-opt %s -generate-runtime-verification \ +// RUN: -expand-strided-metadata \ +// RUN: -test-cf-assert \ +// RUN: -convert-to-llvm="allow-pattern-rollback=0" \ +// RUN: -reconcile-unrealized-casts | \ +// RUN: mlir-runner -e main -entry-point-result=void \ +// RUN: -shared-libs=%mlir_runner_utils 2>&1 | \ +// RUN: FileCheck %s + func.func @main() { %c4 = arith.constant 4 : index %alloca = memref.alloca() : memref<1xf32> diff --git a/mlir/test/Integration/Dialect/MemRef/load-runtime-verification.mlir b/mlir/test/Integration/Dialect/MemRef/load-runtime-verification.mlir index b87e5bdf0970c..2e42648297875 100644 --- a/mlir/test/Integration/Dialect/MemRef/load-runtime-verification.mlir +++ b/mlir/test/Integration/Dialect/MemRef/load-runtime-verification.mlir @@ -1,12 +1,20 @@ // RUN: mlir-opt %s -generate-runtime-verification \ -// RUN: -test-cf-assert \ // RUN: -expand-strided-metadata \ -// RUN: -lower-affine \ +// RUN: -test-cf-assert \ // RUN: -convert-to-llvm | \ // RUN: mlir-runner -e main -entry-point-result=void \ // RUN: -shared-libs=%mlir_runner_utils 2>&1 | \ // RUN: FileCheck %s +// RUN: mlir-opt %s -generate-runtime-verification \ +// RUN: -expand-strided-metadata \ +// RUN: -test-cf-assert \ +// RUN: -convert-to-llvm="allow-pattern-rollback=0" \ +// RUN: -reconcile-unrealized-casts | \ +// RUN: mlir-runner -e main -entry-point-result=void \ +// RUN: -shared-libs=%mlir_runner_utils 2>&1 | \ +// RUN: FileCheck %s + func.func @load(%memref: memref<1xf32>, %index: index) { memref.load %memref[%index] : memref<1xf32> return diff --git a/mlir/test/Integration/Dialect/MemRef/store-runtime-verification.mlir b/mlir/test/Integration/Dialect/MemRef/store-runtime-verification.mlir index 12253fa3b5e83..dd000c6904bcb 100644 --- a/mlir/test/Integration/Dialect/MemRef/store-runtime-verification.mlir +++ b/mlir/test/Integration/Dialect/MemRef/store-runtime-verification.mlir @@ -5,6 +5,14 @@ // RUN: -shared-libs=%mlir_runner_utils 2>&1 | \ // RUN: FileCheck %s +// RUN: mlir-opt %s -generate-runtime-verification \ +// RUN: -test-cf-assert \ +// RUN: -convert-to-llvm="allow-pattern-rollback=0" \ +// RUN: -reconcile-unrealized-casts | \ +// RUN: mlir-runner -e main -entry-point-result=void \ +// RUN: -shared-libs=%mlir_runner_utils 2>&1 | \ +// RUN: FileCheck %s + func.func @store_dynamic(%memref: memref, %index: index) { %cst = arith.constant 1.0 : f32 memref.store %cst, %memref[%index] : memref diff --git a/mlir/test/Integration/Dialect/MemRef/subview-runtime-verification.mlir b/mlir/test/Integration/Dialect/MemRef/subview-runtime-verification.mlir index ec7e4085f2fa5..9fbe5bc60321e 100644 --- a/mlir/test/Integration/Dialect/MemRef/subview-runtime-verification.mlir +++ b/mlir/test/Integration/Dialect/MemRef/subview-runtime-verification.mlir @@ -1,12 +1,22 @@ // RUN: mlir-opt %s -generate-runtime-verification \ -// RUN: -test-cf-assert \ // RUN: -expand-strided-metadata \ // RUN: -lower-affine \ +// RUN: -test-cf-assert \ // RUN: -convert-to-llvm | \ // RUN: mlir-runner -e main -entry-point-result=void \ // RUN: -shared-libs=%mlir_runner_utils 2>&1 | \ // RUN: FileCheck %s +// RUN: mlir-opt %s -generate-runtime-verification \ +// RUN: -expand-strided-metadata \ +// RUN: -lower-affine \ +// RUN: -test-cf-assert \ +// RUN: -convert-to-llvm="allow-pattern-rollback=0" \ +// RUN: -reconcile-unrealized-casts | \ +// RUN: mlir-runner -e main -entry-point-result=void \ +// RUN: -shared-libs=%mlir_runner_utils 2>&1 | \ +// RUN: FileCheck %s + func.func @subview(%memref: memref<1xf32>, %offset: index) { memref.subview %memref[%offset] [1] [1] : memref<1xf32> to diff --git a/mlir/test/Integration/Dialect/Tensor/cast-runtime-verification.mlir b/mlir/test/Integration/Dialect/Tensor/cast-runtime-verification.mlir index e4aab32d4a390..f37a6d6383c48 100644 --- a/mlir/test/Integration/Dialect/Tensor/cast-runtime-verification.mlir +++ b/mlir/test/Integration/Dialect/Tensor/cast-runtime-verification.mlir @@ -8,6 +8,17 @@ // RUN: -shared-libs=%tlir_runner_utils 2>&1 | \ // RUN: FileCheck %s +// RUN: mlir-opt %s -generate-runtime-verification \ +// RUN: -one-shot-bufferize="bufferize-function-boundaries" \ +// RUN: -buffer-deallocation-pipeline=private-function-dynamic-ownership \ +// RUN: -test-cf-assert \ +// RUN: -convert-scf-to-cf \ +// RUN: -convert-to-llvm="allow-pattern-rollback=0" \ +// RUN: -reconcile-unrealized-casts | \ +// RUN: mlir-runner -e main -entry-point-result=void \ +// RUN: -shared-libs=%tlir_runner_utils 2>&1 | \ +// RUN: FileCheck %s + func.func private @cast_to_static_dim(%t: tensor) -> tensor<10xf32> { %0 = tensor.cast %t : tensor to tensor<10xf32> return %0 : tensor<10xf32> diff --git a/mlir/test/Integration/Dialect/Tensor/dim-runtime-verification.mlir b/mlir/test/Integration/Dialect/Tensor/dim-runtime-verification.mlir index c6d8f698b9433..e9e5c040c6488 100644 --- a/mlir/test/Integration/Dialect/Tensor/dim-runtime-verification.mlir +++ b/mlir/test/Integration/Dialect/Tensor/dim-runtime-verification.mlir @@ -1,10 +1,20 @@ // RUN: mlir-opt %s -generate-runtime-verification \ -// RUN: -one-shot-bufferize \ -// RUN: -buffer-deallocation-pipeline \ +// RUN: -one-shot-bufferize="bufferize-function-boundaries" \ +// RUN: -buffer-deallocation-pipeline=private-function-dynamic-ownership \ // RUN: -test-cf-assert \ // RUN: -convert-to-llvm | \ // RUN: mlir-runner -e main -entry-point-result=void \ -// RUN: -shared-libs=%mlir_runner_utils 2>&1 | \ +// RUN: -shared-libs=%tlir_runner_utils 2>&1 | \ +// RUN: FileCheck %s + +// RUN: mlir-opt %s -generate-runtime-verification \ +// RUN: -one-shot-bufferize="bufferize-function-boundaries" \ +// RUN: -buffer-deallocation-pipeline=private-function-dynamic-ownership \ +// RUN: -test-cf-assert \ +// RUN: -convert-to-llvm="allow-pattern-rollback=0" \ +// RUN: -reconcile-unrealized-casts | \ +// RUN: mlir-runner -e main -entry-point-result=void \ +// RUN: -shared-libs=%tlir_runner_utils 2>&1 | \ // RUN: FileCheck %s func.func @main() { diff --git a/mlir/test/Integration/Dialect/Tensor/extract-runtime-verification.mlir b/mlir/test/Integration/Dialect/Tensor/extract-runtime-verification.mlir index 8e3cab7be704d..73fcec4d7abcd 100644 --- a/mlir/test/Integration/Dialect/Tensor/extract-runtime-verification.mlir +++ b/mlir/test/Integration/Dialect/Tensor/extract-runtime-verification.mlir @@ -8,6 +8,17 @@ // RUN: -shared-libs=%tlir_runner_utils 2>&1 | \ // RUN: FileCheck %s +// RUN: mlir-opt %s -generate-runtime-verification \ +// RUN: -one-shot-bufferize="bufferize-function-boundaries" \ +// RUN: -buffer-deallocation-pipeline=private-function-dynamic-ownership \ +// RUN: -test-cf-assert \ +// RUN: -convert-scf-to-cf \ +// RUN: -convert-to-llvm="allow-pattern-rollback=0" \ +// RUN: -reconcile-unrealized-casts | \ +// RUN: mlir-runner -e main -entry-point-result=void \ +// RUN: -shared-libs=%tlir_runner_utils 2>&1 | \ +// RUN: FileCheck %s + func.func @extract(%tensor: tensor<1xf32>, %index: index) { tensor.extract %tensor[%index] : tensor<1xf32> return diff --git a/mlir/test/Integration/Dialect/Tensor/extract_slice-runtime-verification.mlir b/mlir/test/Integration/Dialect/Tensor/extract_slice-runtime-verification.mlir index 28f9be0fffe64..341a59e8b8102 100644 --- a/mlir/test/Integration/Dialect/Tensor/extract_slice-runtime-verification.mlir +++ b/mlir/test/Integration/Dialect/Tensor/extract_slice-runtime-verification.mlir @@ -8,6 +8,17 @@ // RUN: -shared-libs=%tlir_runner_utils 2>&1 | \ // RUN: FileCheck %s +// RUN: mlir-opt %s -generate-runtime-verification \ +// RUN: -one-shot-bufferize="bufferize-function-boundaries" \ +// RUN: -buffer-deallocation-pipeline=private-function-dynamic-ownership \ +// RUN: -test-cf-assert \ +// RUN: -convert-scf-to-cf \ +// RUN: -convert-to-llvm="allow-pattern-rollback=0" \ +// RUN: -reconcile-unrealized-casts | \ +// RUN: mlir-runner -e main -entry-point-result=void \ +// RUN: -shared-libs=%tlir_runner_utils 2>&1 | \ +// RUN: FileCheck %s + func.func @extract_slice(%tensor: tensor<1xf32>, %offset: index) { tensor.extract_slice %tensor[%offset] [1] [1] : tensor<1xf32> to tensor<1xf32> return diff --git a/mlir/test/Transforms/test-legalizer.mlir b/mlir/test/Transforms/test-legalizer.mlir index 5630d1540e4d5..9a04da7904863 100644 --- a/mlir/test/Transforms/test-legalizer.mlir +++ b/mlir/test/Transforms/test-legalizer.mlir @@ -1,9 +1,14 @@ -// RUN: mlir-opt -allow-unregistered-dialect -split-input-file -test-legalize-patterns -verify-diagnostics -profile-actions-to=- %s | FileCheck %s +// RUN: mlir-opt -allow-unregistered-dialect -split-input-file -test-legalize-patterns="allow-pattern-rollback=1" -verify-diagnostics %s | FileCheck %s +// RUN: mlir-opt -allow-unregistered-dialect -split-input-file -test-legalize-patterns="allow-pattern-rollback=1" -verify-diagnostics -profile-actions-to=- %s | FileCheck %s --check-prefix=CHECK-PROFILER +// RUN: mlir-opt -allow-unregistered-dialect -split-input-file -test-legalize-patterns="allow-pattern-rollback=0" -verify-diagnostics %s | FileCheck %s + +// CHECK-PROFILER: "name": "pass-execution", "cat": "PERF", "ph": "B" +// CHECK-PROFILER: "name": "apply-conversion", "cat": "PERF", "ph": "B" +// CHECK-PROFILER: "name": "apply-pattern", "cat": "PERF", "ph": "B" +// CHECK-PROFILER: "name": "apply-pattern", "cat": "PERF", "ph": "E" +// CHECK-PROFILER: "name": "apply-conversion", "cat": "PERF", "ph": "E" +// CHECK-PROFILER: "name": "pass-execution", "cat": "PERF", "ph": "E" -// CHECK: "name": "pass-execution", "cat": "PERF", "ph": "B" -// CHECK: "name": "apply-conversion", "cat": "PERF", "ph": "B" -// CHECK: "name": "apply-pattern", "cat": "PERF", "ph": "B" -// CHECK: "name": "apply-pattern", "cat": "PERF", "ph": "E" // Note: Listener notifications appear after the pattern application because // the conversion driver sends all notifications at the end of the conversion // in bulk. @@ -11,8 +16,6 @@ // CHECK-NEXT: notifyOperationReplaced: test.illegal_op_a // CHECK-NEXT: notifyOperationModified: func.return // CHECK-NEXT: notifyOperationErased: test.illegal_op_a -// CHECK: "name": "apply-conversion", "cat": "PERF", "ph": "E" -// CHECK: "name": "pass-execution", "cat": "PERF", "ph": "E" // CHECK-LABEL: verifyDirectPattern func.func @verifyDirectPattern() -> i32 { // CHECK-NEXT: "test.legal_op_a"() <{status = "Success"} @@ -29,7 +32,9 @@ func.func @verifyDirectPattern() -> i32 { // CHECK-NEXT: notifyOperationErased: test.illegal_op_c // CHECK-NEXT: notifyOperationInserted: test.legal_op_a, was unlinked // CHECK-NEXT: notifyOperationReplaced: test.illegal_op_e -// CHECK-NEXT: notifyOperationErased: test.illegal_op_e +// Note: func.return is modified a second time when running in no-rollback +// mode. +// CHECK: notifyOperationErased: test.illegal_op_e // CHECK-LABEL: verifyLargerBenefit func.func @verifyLargerBenefit() -> i32 { @@ -70,7 +75,7 @@ func.func @remap_call_1_to_1(%arg0: i64) { // CHECK: notifyBlockInserted into func.func: was unlinked // Contents of the old block are moved to the new block. -// CHECK-NEXT: notifyOperationInserted: test.return, was linked, exact position unknown +// CHECK-NEXT: notifyOperationInserted: test.return // The old block is erased. // CHECK-NEXT: notifyBlockErased @@ -409,8 +414,10 @@ func.func @test_remap_block_arg() { // CHECK-LABEL: func @test_multiple_1_to_n_replacement() // CHECK: %[[legal_op:.*]]:4 = "test.legal_op"() : () -> (f16, f16, f16, f16) -// CHECK: %[[cast:.*]] = "test.cast"(%[[legal_op]]#0, %[[legal_op]]#1, %[[legal_op]]#2, %[[legal_op]]#3) : (f16, f16, f16, f16) -> f16 -// CHECK: "test.valid"(%[[cast]]) : (f16) -> () +// Note: There is a bug in the rollback-based conversion driver: it emits a +// "test.cast" : (f16, f16, f16, f16) -> f16, when it should be emitting +// three consecutive casts of (f16, f16) -> f16. +// CHECK: "test.valid"(%{{.*}}) : (f16) -> () func.func @test_multiple_1_to_n_replacement() { %0 = "test.multiple_1_to_n_replacement"() : () -> (f16) "test.invalid"(%0) : (f16) -> () diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp index f92f0982f85b2..bc865b23aa4d8 100644 --- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp +++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp @@ -1177,8 +1177,8 @@ struct TestNonRootReplacement : public RewritePattern { auto illegalOp = ILLegalOpF::create(rewriter, op->getLoc(), resultType); auto legalOp = LegalOpB::create(rewriter, op->getLoc(), resultType); - rewriter.replaceOp(illegalOp, legalOp); rewriter.replaceOp(op, illegalOp); + rewriter.replaceOp(illegalOp, legalOp); return success(); } }; @@ -1362,6 +1362,7 @@ class TestMultiple1ToNReplacement : public ConversionPattern { // Helper function that replaces the given op with a new op of the given // name and doubles each result (1 -> 2 replacement of each result). auto replaceWithDoubleResults = [&](Operation *op, StringRef name) { + rewriter.setInsertionPointAfter(op); SmallVector types; for (Type t : op->getResultTypes()) { types.push_back(t); @@ -1560,6 +1561,7 @@ struct TestLegalizePatternDriver if (mode == ConversionMode::Partial) { DenseSet unlegalizedOps; ConversionConfig config; + config.allowPatternRollback = allowPatternRollback; DumpNotifications dumpNotifications; config.listener = &dumpNotifications; config.unlegalizedOps = &unlegalizedOps; @@ -1581,6 +1583,7 @@ struct TestLegalizePatternDriver }); ConversionConfig config; + config.allowPatternRollback = allowPatternRollback; DumpNotifications dumpNotifications; config.listener = &dumpNotifications; if (failed(applyFullConversion(getOperation(), target, @@ -1596,6 +1599,7 @@ struct TestLegalizePatternDriver // Analyze the convertible operations. DenseSet legalizedOps; ConversionConfig config; + config.allowPatternRollback = allowPatternRollback; config.legalizableOps = &legalizedOps; if (failed(applyAnalysisConversion(getOperation(), target, std::move(patterns), config))) @@ -1616,6 +1620,10 @@ struct TestLegalizePatternDriver clEnumValN(ConversionMode::Full, "full", "Perform a full conversion"), clEnumValN(ConversionMode::Partial, "partial", "Perform a partial conversion"))}; + + Option allowPatternRollback{*this, "allow-pattern-rollback", + llvm::cl::desc("Allow pattern rollback"), + llvm::cl::init(true)}; }; } // namespace