diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp index 42fe5b925654a..03d483f73f255 100644 --- a/mlir/lib/Transforms/Utils/DialectConversion.cpp +++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp @@ -75,6 +75,10 @@ namespace { /// This class wraps a IRMapping to provide recursive lookup /// functionality, i.e. we will traverse if the mapped value also has a mapping. struct ConversionValueMapping { + /// Return "true" if an SSA value is mapped to the given value. May return + /// false positives. + bool isMappedTo(Value value) const { return mappedTo.contains(value); } + /// Lookup the most recently mapped value with the desired type in the /// mapping. /// @@ -99,22 +103,18 @@ struct ConversionValueMapping { assert(it != oldVal && "inserting cyclic mapping"); }); mapping.map(oldVal, newVal); + mappedTo.insert(newVal); } /// Drop the last mapping for the given value. void erase(Value value) { mapping.erase(value); } - /// Returns the inverse raw value mapping (without recursive query support). - DenseMap> getInverse() const { - DenseMap> inverse; - for (auto &it : mapping.getValueMap()) - inverse[it.second].push_back(it.first); - return inverse; - } - private: /// Current value mappings. IRMapping mapping; + + /// All SSA values that are mapped to. May contain false positives. + DenseSet mappedTo; }; } // namespace @@ -434,10 +434,9 @@ class MoveBlockRewrite : public BlockRewrite { class BlockTypeConversionRewrite : public BlockRewrite { public: BlockTypeConversionRewrite(ConversionPatternRewriterImpl &rewriterImpl, - Block *block, Block *origBlock, - const TypeConverter *converter) + Block *block, Block *origBlock) : BlockRewrite(Kind::BlockTypeConversion, rewriterImpl, block), - origBlock(origBlock), converter(converter) {} + origBlock(origBlock) {} static bool classof(const IRRewrite *rewrite) { return rewrite->getKind() == Kind::BlockTypeConversion; @@ -445,8 +444,6 @@ class BlockTypeConversionRewrite : public BlockRewrite { Block *getOrigBlock() const { return origBlock; } - const TypeConverter *getConverter() const { return converter; } - void commit(RewriterBase &rewriter) override; void rollback() override; @@ -454,9 +451,6 @@ class BlockTypeConversionRewrite : public BlockRewrite { private: /// The original block that was requested to have its signature converted. Block *origBlock; - - /// The type converter used to convert the arguments. - const TypeConverter *converter; }; /// Replacing a block argument. This rewrite is not immediately reflected in the @@ -465,8 +459,10 @@ class BlockTypeConversionRewrite : public BlockRewrite { class ReplaceBlockArgRewrite : public BlockRewrite { public: ReplaceBlockArgRewrite(ConversionPatternRewriterImpl &rewriterImpl, - Block *block, BlockArgument arg) - : BlockRewrite(Kind::ReplaceBlockArg, rewriterImpl, block), arg(arg) {} + Block *block, BlockArgument arg, + const TypeConverter *converter) + : BlockRewrite(Kind::ReplaceBlockArg, rewriterImpl, block), arg(arg), + converter(converter) {} static bool classof(const IRRewrite *rewrite) { return rewrite->getKind() == Kind::ReplaceBlockArg; @@ -478,6 +474,9 @@ class ReplaceBlockArgRewrite : public BlockRewrite { private: BlockArgument arg; + + /// The current type converter when the block argument was replaced. + const TypeConverter *converter; }; /// An operation rewrite. @@ -627,8 +626,6 @@ class ReplaceOperationRewrite : public OperationRewrite { void cleanup(RewriterBase &rewriter) override; - const TypeConverter *getConverter() const { return converter; } - private: /// An optional type converter that can be used to materialize conversions /// between the new and old values if necessary. @@ -825,6 +822,14 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener { ValueRange replacements, Value originalValue, const TypeConverter *converter); + /// Find a replacement value for the given SSA value in the conversion value + /// mapping. The replacement value must have the same type as the given SSA + /// value. If there is no replacement value with the correct type, find the + /// latest replacement value (regardless of the type) and build a source + /// materialization. + Value findOrBuildReplacementValue(Value value, + const TypeConverter *converter); + //===--------------------------------------------------------------------===// // Rewriter Notification Hooks //===--------------------------------------------------------------------===// @@ -970,7 +975,7 @@ void BlockTypeConversionRewrite::rollback() { } void ReplaceBlockArgRewrite::commit(RewriterBase &rewriter) { - Value repl = rewriterImpl.mapping.lookupOrNull(arg, arg.getType()); + Value repl = rewriterImpl.findOrBuildReplacementValue(arg, converter); if (!repl) return; @@ -999,7 +1004,7 @@ void ReplaceOperationRewrite::commit(RewriterBase &rewriter) { // Compute replacement values. SmallVector replacements = llvm::map_to_vector(op->getResults(), [&](OpResult result) { - return rewriterImpl.mapping.lookupOrNull(result, result.getType()); + return rewriterImpl.findOrBuildReplacementValue(result, converter); }); // Notify the listener that the operation is about to be replaced. @@ -1069,8 +1074,10 @@ void UnresolvedMaterializationRewrite::rollback() { void ConversionPatternRewriterImpl::applyRewrites() { // Commit all rewrites. IRRewriter rewriter(context, config.listener); - for (auto &rewrite : rewrites) - rewrite->commit(rewriter); + // Note: New rewrites may be added during the "commit" phase and the + // `rewrites` vector may reallocate. + for (size_t i = 0; i < rewrites.size(); ++i) + rewrites[i]->commit(rewriter); // Clean up all rewrites. for (auto &rewrite : rewrites) @@ -1275,7 +1282,7 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion( /*inputs=*/ValueRange(), /*outputType=*/origArgType, /*originalType=*/Type(), converter); mapping.map(origArg, repl); - appendRewrite(block, origArg); + appendRewrite(block, origArg, converter); continue; } @@ -1285,7 +1292,7 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion( "invalid to provide a replacement value when the argument isn't " "dropped"); mapping.map(origArg, repl); - appendRewrite(block, origArg); + appendRewrite(block, origArg, converter); continue; } @@ -1298,10 +1305,10 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion( insertNTo1Materialization( OpBuilder::InsertPoint(newBlock, newBlock->begin()), origArg.getLoc(), /*replacements=*/replArgs, /*outputValue=*/origArg, converter); - appendRewrite(block, origArg); + appendRewrite(block, origArg, converter); } - appendRewrite(newBlock, block, converter); + appendRewrite(newBlock, block); // Erase the old block. (It is just unlinked for now and will be erased during // cleanup.) @@ -1371,6 +1378,41 @@ void ConversionPatternRewriterImpl::insertNTo1Materialization( } } +Value ConversionPatternRewriterImpl::findOrBuildReplacementValue( + Value value, const TypeConverter *converter) { + // Find a replacement value with the same type. + Value repl = mapping.lookupOrNull(value, value.getType()); + if (repl) + return repl; + + // Check if the value is dead. No replacement value is needed in that case. + // This is an approximate check that may have false negatives but does not + // require computing and traversing an inverse mapping. (We may end up + // building source materializations that are never used and that fold away.) + if (llvm::all_of(value.getUsers(), + [&](Operation *op) { return replacedOps.contains(op); }) && + !mapping.isMappedTo(value)) + return Value(); + + // No replacement value was found. Get the latest replacement value + // (regardless of the type) and build a source materialization to the + // original type. + repl = mapping.lookupOrNull(value); + if (!repl) { + // No replacement value is registered in the mapping. This means that the + // value is dropped and no longer needed. (If the value were still needed, + // a source materialization producing a replacement value "out of thin air" + // would have already been created during `replaceOp` or + // `applySignatureConversion`.) + return Value(); + } + Value castValue = buildUnresolvedMaterialization( + MaterializationKind::Source, computeInsertPoint(repl), value.getLoc(), + /*inputs=*/repl, /*outputType=*/value.getType(), + /*originalType=*/Type(), converter); + return castValue; +} + //===----------------------------------------------------------------------===// // Rewriter Notification Hooks @@ -1597,7 +1639,8 @@ void ConversionPatternRewriter::replaceUsesOfBlockArgument(BlockArgument from, << "'(in region of '" << parentOp->getName() << "'(" << from.getOwner()->getParentOp() << ")\n"; }); - impl->appendRewrite(from.getOwner(), from); + impl->appendRewrite(from.getOwner(), from, + impl->currentTypeConverter); impl->mapping.map(impl->mapping.lookupOrDefault(from), to); } @@ -2417,10 +2460,6 @@ struct OperationConverter { /// Converts an operation with the given rewriter. LogicalResult convert(ConversionPatternRewriter &rewriter, Operation *op); - /// This method is called after the conversion process to legalize any - /// remaining artifacts and complete the conversion. - void finalize(ConversionPatternRewriter &rewriter); - /// Dialect conversion configuration. ConversionConfig config; @@ -2541,11 +2580,6 @@ LogicalResult OperationConverter::convertOperations(ArrayRef ops) { if (failed(convert(rewriter, op))) return rewriterImpl.undoRewrites(), failure(); - // Now that all of the operations have been converted, finalize the conversion - // process to ensure any lingering conversion artifacts are cleaned up and - // legalized. - finalize(rewriter); - // After a successful conversion, apply rewrites. rewriterImpl.applyRewrites(); @@ -2579,80 +2613,6 @@ LogicalResult OperationConverter::convertOperations(ArrayRef ops) { return success(); } -/// Finds a user of the given value, or of any other value that the given value -/// replaced, that was not replaced in the conversion process. -static Operation *findLiveUserOfReplaced( - Value initialValue, ConversionPatternRewriterImpl &rewriterImpl, - const DenseMap> &inverseMapping) { - SmallVector worklist = {initialValue}; - while (!worklist.empty()) { - Value value = worklist.pop_back_val(); - - // Walk the users of this value to see if there are any live users that - // weren't replaced during conversion. - auto liveUserIt = llvm::find_if_not(value.getUsers(), [&](Operation *user) { - return rewriterImpl.isOpIgnored(user); - }); - if (liveUserIt != value.user_end()) - return *liveUserIt; - auto mapIt = inverseMapping.find(value); - if (mapIt != inverseMapping.end()) - worklist.append(mapIt->second); - } - return nullptr; -} - -/// Helper function that returns the replaced values and the type converter if -/// the given rewrite object is an "operation replacement" or a "block type -/// conversion" (which corresponds to a "block replacement"). Otherwise, return -/// an empty ValueRange and a null type converter pointer. -static std::pair -getReplacedValues(IRRewrite *rewrite) { - if (auto *opRewrite = dyn_cast(rewrite)) - return {opRewrite->getOperation()->getResults(), opRewrite->getConverter()}; - if (auto *blockRewrite = dyn_cast(rewrite)) - return {blockRewrite->getOrigBlock()->getArguments(), - blockRewrite->getConverter()}; - return {}; -} - -void OperationConverter::finalize(ConversionPatternRewriter &rewriter) { - ConversionPatternRewriterImpl &rewriterImpl = rewriter.getImpl(); - DenseMap> inverseMapping = - rewriterImpl.mapping.getInverse(); - - // Process requested value replacements. - for (unsigned i = 0, e = rewriterImpl.rewrites.size(); i < e; ++i) { - ValueRange replacedValues; - const TypeConverter *converter; - std::tie(replacedValues, converter) = - getReplacedValues(rewriterImpl.rewrites[i].get()); - for (Value originalValue : replacedValues) { - // If the type of this value changed and the value is still live, we need - // to materialize a conversion. - if (rewriterImpl.mapping.lookupOrNull(originalValue, - originalValue.getType())) - continue; - Operation *liveUser = - findLiveUserOfReplaced(originalValue, rewriterImpl, inverseMapping); - if (!liveUser) - continue; - - // Legalize this value replacement. - Value newValue = rewriterImpl.mapping.lookupOrNull(originalValue); - assert(newValue && "replacement value not found"); - Value castValue = rewriterImpl.buildUnresolvedMaterialization( - MaterializationKind::Source, computeInsertPoint(newValue), - originalValue.getLoc(), - /*inputs=*/newValue, /*outputType=*/originalValue.getType(), - /*originalType=*/Type(), converter); - rewriterImpl.mapping.map(originalValue, castValue); - inverseMapping[castValue].push_back(originalValue); - llvm::erase(inverseMapping[newValue], originalValue); - } - } -} - //===----------------------------------------------------------------------===// // Reconcile Unrealized Casts //===----------------------------------------------------------------------===//