-
Notifications
You must be signed in to change notification settings - Fork 15.4k
Revert "[mlir][Transforms][NFC] Dialect conversion: Remove "finalize" phase" #117094
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
matthias-springer
merged 1 commit into
main
from
revert-116934-users/matthias-springer/no_inverse_mapping
Nov 21, 2024
Merged
Revert "[mlir][Transforms][NFC] Dialect conversion: Remove "finalize" phase" #117094
matthias-springer
merged 1 commit into
main
from
revert-116934-users/matthias-springer/no_inverse_mapping
Nov 21, 2024
Conversation
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
… phase (…" This reverts commit aa65473.
Member
|
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-core Author: Matthias Springer (matthias-springer) ChangesReverts llvm/llvm-project#116934 This commit broke the build. Full diff: https://github.com/llvm/llvm-project/pull/117094.diff 1 Files Affected:
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index 03d483f73f255e..42fe5b925654a1 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -75,10 +75,6 @@ namespace {
/// This class wraps a IRMapping to provide recursive lookup
/// functionality, i.e. we will traverse if the mapped value also has a mapping.
struct ConversionValueMapping {
- /// Return "true" if an SSA value is mapped to the given value. May return
- /// false positives.
- bool isMappedTo(Value value) const { return mappedTo.contains(value); }
-
/// Lookup the most recently mapped value with the desired type in the
/// mapping.
///
@@ -103,18 +99,22 @@ struct ConversionValueMapping {
assert(it != oldVal && "inserting cyclic mapping");
});
mapping.map(oldVal, newVal);
- mappedTo.insert(newVal);
}
/// Drop the last mapping for the given value.
void erase(Value value) { mapping.erase(value); }
+ /// Returns the inverse raw value mapping (without recursive query support).
+ DenseMap<Value, SmallVector<Value>> getInverse() const {
+ DenseMap<Value, SmallVector<Value>> inverse;
+ for (auto &it : mapping.getValueMap())
+ inverse[it.second].push_back(it.first);
+ return inverse;
+ }
+
private:
/// Current value mappings.
IRMapping mapping;
-
- /// All SSA values that are mapped to. May contain false positives.
- DenseSet<Value> mappedTo;
};
} // namespace
@@ -434,9 +434,10 @@ class MoveBlockRewrite : public BlockRewrite {
class BlockTypeConversionRewrite : public BlockRewrite {
public:
BlockTypeConversionRewrite(ConversionPatternRewriterImpl &rewriterImpl,
- Block *block, Block *origBlock)
+ Block *block, Block *origBlock,
+ const TypeConverter *converter)
: BlockRewrite(Kind::BlockTypeConversion, rewriterImpl, block),
- origBlock(origBlock) {}
+ origBlock(origBlock), converter(converter) {}
static bool classof(const IRRewrite *rewrite) {
return rewrite->getKind() == Kind::BlockTypeConversion;
@@ -444,6 +445,8 @@ class BlockTypeConversionRewrite : public BlockRewrite {
Block *getOrigBlock() const { return origBlock; }
+ const TypeConverter *getConverter() const { return converter; }
+
void commit(RewriterBase &rewriter) override;
void rollback() override;
@@ -451,6 +454,9 @@ class BlockTypeConversionRewrite : public BlockRewrite {
private:
/// The original block that was requested to have its signature converted.
Block *origBlock;
+
+ /// The type converter used to convert the arguments.
+ const TypeConverter *converter;
};
/// Replacing a block argument. This rewrite is not immediately reflected in the
@@ -459,10 +465,8 @@ class BlockTypeConversionRewrite : public BlockRewrite {
class ReplaceBlockArgRewrite : public BlockRewrite {
public:
ReplaceBlockArgRewrite(ConversionPatternRewriterImpl &rewriterImpl,
- Block *block, BlockArgument arg,
- const TypeConverter *converter)
- : BlockRewrite(Kind::ReplaceBlockArg, rewriterImpl, block), arg(arg),
- converter(converter) {}
+ Block *block, BlockArgument arg)
+ : BlockRewrite(Kind::ReplaceBlockArg, rewriterImpl, block), arg(arg) {}
static bool classof(const IRRewrite *rewrite) {
return rewrite->getKind() == Kind::ReplaceBlockArg;
@@ -474,9 +478,6 @@ class ReplaceBlockArgRewrite : public BlockRewrite {
private:
BlockArgument arg;
-
- /// The current type converter when the block argument was replaced.
- const TypeConverter *converter;
};
/// An operation rewrite.
@@ -626,6 +627,8 @@ class ReplaceOperationRewrite : public OperationRewrite {
void cleanup(RewriterBase &rewriter) override;
+ const TypeConverter *getConverter() const { return converter; }
+
private:
/// An optional type converter that can be used to materialize conversions
/// between the new and old values if necessary.
@@ -822,14 +825,6 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
ValueRange replacements, Value originalValue,
const TypeConverter *converter);
- /// Find a replacement value for the given SSA value in the conversion value
- /// mapping. The replacement value must have the same type as the given SSA
- /// value. If there is no replacement value with the correct type, find the
- /// latest replacement value (regardless of the type) and build a source
- /// materialization.
- Value findOrBuildReplacementValue(Value value,
- const TypeConverter *converter);
-
//===--------------------------------------------------------------------===//
// Rewriter Notification Hooks
//===--------------------------------------------------------------------===//
@@ -975,7 +970,7 @@ void BlockTypeConversionRewrite::rollback() {
}
void ReplaceBlockArgRewrite::commit(RewriterBase &rewriter) {
- Value repl = rewriterImpl.findOrBuildReplacementValue(arg, converter);
+ Value repl = rewriterImpl.mapping.lookupOrNull(arg, arg.getType());
if (!repl)
return;
@@ -1004,7 +999,7 @@ void ReplaceOperationRewrite::commit(RewriterBase &rewriter) {
// Compute replacement values.
SmallVector<Value> replacements =
llvm::map_to_vector(op->getResults(), [&](OpResult result) {
- return rewriterImpl.findOrBuildReplacementValue(result, converter);
+ return rewriterImpl.mapping.lookupOrNull(result, result.getType());
});
// Notify the listener that the operation is about to be replaced.
@@ -1074,10 +1069,8 @@ void UnresolvedMaterializationRewrite::rollback() {
void ConversionPatternRewriterImpl::applyRewrites() {
// Commit all rewrites.
IRRewriter rewriter(context, config.listener);
- // Note: New rewrites may be added during the "commit" phase and the
- // `rewrites` vector may reallocate.
- for (size_t i = 0; i < rewrites.size(); ++i)
- rewrites[i]->commit(rewriter);
+ for (auto &rewrite : rewrites)
+ rewrite->commit(rewriter);
// Clean up all rewrites.
for (auto &rewrite : rewrites)
@@ -1282,7 +1275,7 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
/*inputs=*/ValueRange(),
/*outputType=*/origArgType, /*originalType=*/Type(), converter);
mapping.map(origArg, repl);
- appendRewrite<ReplaceBlockArgRewrite>(block, origArg, converter);
+ appendRewrite<ReplaceBlockArgRewrite>(block, origArg);
continue;
}
@@ -1292,7 +1285,7 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
"invalid to provide a replacement value when the argument isn't "
"dropped");
mapping.map(origArg, repl);
- appendRewrite<ReplaceBlockArgRewrite>(block, origArg, converter);
+ appendRewrite<ReplaceBlockArgRewrite>(block, origArg);
continue;
}
@@ -1305,10 +1298,10 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
insertNTo1Materialization(
OpBuilder::InsertPoint(newBlock, newBlock->begin()), origArg.getLoc(),
/*replacements=*/replArgs, /*outputValue=*/origArg, converter);
- appendRewrite<ReplaceBlockArgRewrite>(block, origArg, converter);
+ appendRewrite<ReplaceBlockArgRewrite>(block, origArg);
}
- appendRewrite<BlockTypeConversionRewrite>(newBlock, block);
+ appendRewrite<BlockTypeConversionRewrite>(newBlock, block, converter);
// Erase the old block. (It is just unlinked for now and will be erased during
// cleanup.)
@@ -1378,41 +1371,6 @@ void ConversionPatternRewriterImpl::insertNTo1Materialization(
}
}
-Value ConversionPatternRewriterImpl::findOrBuildReplacementValue(
- Value value, const TypeConverter *converter) {
- // Find a replacement value with the same type.
- Value repl = mapping.lookupOrNull(value, value.getType());
- if (repl)
- return repl;
-
- // Check if the value is dead. No replacement value is needed in that case.
- // This is an approximate check that may have false negatives but does not
- // require computing and traversing an inverse mapping. (We may end up
- // building source materializations that are never used and that fold away.)
- if (llvm::all_of(value.getUsers(),
- [&](Operation *op) { return replacedOps.contains(op); }) &&
- !mapping.isMappedTo(value))
- return Value();
-
- // No replacement value was found. Get the latest replacement value
- // (regardless of the type) and build a source materialization to the
- // original type.
- repl = mapping.lookupOrNull(value);
- if (!repl) {
- // No replacement value is registered in the mapping. This means that the
- // value is dropped and no longer needed. (If the value were still needed,
- // a source materialization producing a replacement value "out of thin air"
- // would have already been created during `replaceOp` or
- // `applySignatureConversion`.)
- return Value();
- }
- Value castValue = buildUnresolvedMaterialization(
- MaterializationKind::Source, computeInsertPoint(repl), value.getLoc(),
- /*inputs=*/repl, /*outputType=*/value.getType(),
- /*originalType=*/Type(), converter);
- return castValue;
-}
-
//===----------------------------------------------------------------------===//
// Rewriter Notification Hooks
@@ -1639,8 +1597,7 @@ void ConversionPatternRewriter::replaceUsesOfBlockArgument(BlockArgument from,
<< "'(in region of '" << parentOp->getName()
<< "'(" << from.getOwner()->getParentOp() << ")\n";
});
- impl->appendRewrite<ReplaceBlockArgRewrite>(from.getOwner(), from,
- impl->currentTypeConverter);
+ impl->appendRewrite<ReplaceBlockArgRewrite>(from.getOwner(), from);
impl->mapping.map(impl->mapping.lookupOrDefault(from), to);
}
@@ -2460,6 +2417,10 @@ struct OperationConverter {
/// Converts an operation with the given rewriter.
LogicalResult convert(ConversionPatternRewriter &rewriter, Operation *op);
+ /// This method is called after the conversion process to legalize any
+ /// remaining artifacts and complete the conversion.
+ void finalize(ConversionPatternRewriter &rewriter);
+
/// Dialect conversion configuration.
ConversionConfig config;
@@ -2580,6 +2541,11 @@ LogicalResult OperationConverter::convertOperations(ArrayRef<Operation *> ops) {
if (failed(convert(rewriter, op)))
return rewriterImpl.undoRewrites(), failure();
+ // Now that all of the operations have been converted, finalize the conversion
+ // process to ensure any lingering conversion artifacts are cleaned up and
+ // legalized.
+ finalize(rewriter);
+
// After a successful conversion, apply rewrites.
rewriterImpl.applyRewrites();
@@ -2613,6 +2579,80 @@ LogicalResult OperationConverter::convertOperations(ArrayRef<Operation *> ops) {
return success();
}
+/// Finds a user of the given value, or of any other value that the given value
+/// replaced, that was not replaced in the conversion process.
+static Operation *findLiveUserOfReplaced(
+ Value initialValue, ConversionPatternRewriterImpl &rewriterImpl,
+ const DenseMap<Value, SmallVector<Value>> &inverseMapping) {
+ SmallVector<Value> worklist = {initialValue};
+ while (!worklist.empty()) {
+ Value value = worklist.pop_back_val();
+
+ // Walk the users of this value to see if there are any live users that
+ // weren't replaced during conversion.
+ auto liveUserIt = llvm::find_if_not(value.getUsers(), [&](Operation *user) {
+ return rewriterImpl.isOpIgnored(user);
+ });
+ if (liveUserIt != value.user_end())
+ return *liveUserIt;
+ auto mapIt = inverseMapping.find(value);
+ if (mapIt != inverseMapping.end())
+ worklist.append(mapIt->second);
+ }
+ return nullptr;
+}
+
+/// Helper function that returns the replaced values and the type converter if
+/// the given rewrite object is an "operation replacement" or a "block type
+/// conversion" (which corresponds to a "block replacement"). Otherwise, return
+/// an empty ValueRange and a null type converter pointer.
+static std::pair<ValueRange, const TypeConverter *>
+getReplacedValues(IRRewrite *rewrite) {
+ if (auto *opRewrite = dyn_cast<ReplaceOperationRewrite>(rewrite))
+ return {opRewrite->getOperation()->getResults(), opRewrite->getConverter()};
+ if (auto *blockRewrite = dyn_cast<BlockTypeConversionRewrite>(rewrite))
+ return {blockRewrite->getOrigBlock()->getArguments(),
+ blockRewrite->getConverter()};
+ return {};
+}
+
+void OperationConverter::finalize(ConversionPatternRewriter &rewriter) {
+ ConversionPatternRewriterImpl &rewriterImpl = rewriter.getImpl();
+ DenseMap<Value, SmallVector<Value>> inverseMapping =
+ rewriterImpl.mapping.getInverse();
+
+ // Process requested value replacements.
+ for (unsigned i = 0, e = rewriterImpl.rewrites.size(); i < e; ++i) {
+ ValueRange replacedValues;
+ const TypeConverter *converter;
+ std::tie(replacedValues, converter) =
+ getReplacedValues(rewriterImpl.rewrites[i].get());
+ for (Value originalValue : replacedValues) {
+ // If the type of this value changed and the value is still live, we need
+ // to materialize a conversion.
+ if (rewriterImpl.mapping.lookupOrNull(originalValue,
+ originalValue.getType()))
+ continue;
+ Operation *liveUser =
+ findLiveUserOfReplaced(originalValue, rewriterImpl, inverseMapping);
+ if (!liveUser)
+ continue;
+
+ // Legalize this value replacement.
+ Value newValue = rewriterImpl.mapping.lookupOrNull(originalValue);
+ assert(newValue && "replacement value not found");
+ Value castValue = rewriterImpl.buildUnresolvedMaterialization(
+ MaterializationKind::Source, computeInsertPoint(newValue),
+ originalValue.getLoc(),
+ /*inputs=*/newValue, /*outputType=*/originalValue.getType(),
+ /*originalType=*/Type(), converter);
+ rewriterImpl.mapping.map(originalValue, castValue);
+ inverseMapping[castValue].push_back(originalValue);
+ llvm::erase(inverseMapping[newValue], originalValue);
+ }
+ }
+}
+
//===----------------------------------------------------------------------===//
// Reconcile Unrealized Casts
//===----------------------------------------------------------------------===//
|
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Reverts #116934
This commit broke the build.