Skip to content

Commit dc7f02e

Browse files
[mlir][Transforms] Dialect conversion: Remove "finalize" phase
1 parent 04de524 commit dc7f02e

File tree

1 file changed

+72
-112
lines changed

1 file changed

+72
-112
lines changed

mlir/lib/Transforms/Utils/DialectConversion.cpp

Lines changed: 72 additions & 112 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,10 @@ namespace {
7575
/// This class wraps a IRMapping to provide recursive lookup
7676
/// functionality, i.e. we will traverse if the mapped value also has a mapping.
7777
struct ConversionValueMapping {
78+
/// Return "true" if an SSA value is mapped to the given value. May return
79+
/// false positives.
80+
bool isMappedTo(Value value) const { return mappedTo.contains(value); }
81+
7882
/// Lookup the most recently mapped value with the desired type in the
7983
/// mapping.
8084
///
@@ -99,22 +103,18 @@ struct ConversionValueMapping {
99103
assert(it != oldVal && "inserting cyclic mapping");
100104
});
101105
mapping.map(oldVal, newVal);
106+
mappedTo.insert(newVal);
102107
}
103108

104109
/// Drop the last mapping for the given value.
105110
void erase(Value value) { mapping.erase(value); }
106111

107-
/// Returns the inverse raw value mapping (without recursive query support).
108-
DenseMap<Value, SmallVector<Value>> getInverse() const {
109-
DenseMap<Value, SmallVector<Value>> inverse;
110-
for (auto &it : mapping.getValueMap())
111-
inverse[it.second].push_back(it.first);
112-
return inverse;
113-
}
114-
115112
private:
116113
/// Current value mappings.
117114
IRMapping mapping;
115+
116+
/// All SSA values that are mapped to. May contain false positives.
117+
DenseSet<Value> mappedTo;
118118
};
119119
} // namespace
120120

@@ -434,29 +434,23 @@ class MoveBlockRewrite : public BlockRewrite {
434434
class BlockTypeConversionRewrite : public BlockRewrite {
435435
public:
436436
BlockTypeConversionRewrite(ConversionPatternRewriterImpl &rewriterImpl,
437-
Block *block, Block *origBlock,
438-
const TypeConverter *converter)
437+
Block *block, Block *origBlock)
439438
: BlockRewrite(Kind::BlockTypeConversion, rewriterImpl, block),
440-
origBlock(origBlock), converter(converter) {}
439+
origBlock(origBlock) {}
441440

442441
static bool classof(const IRRewrite *rewrite) {
443442
return rewrite->getKind() == Kind::BlockTypeConversion;
444443
}
445444

446445
Block *getOrigBlock() const { return origBlock; }
447446

448-
const TypeConverter *getConverter() const { return converter; }
449-
450447
void commit(RewriterBase &rewriter) override;
451448

452449
void rollback() override;
453450

454451
private:
455452
/// The original block that was requested to have its signature converted.
456453
Block *origBlock;
457-
458-
/// The type converter used to convert the arguments.
459-
const TypeConverter *converter;
460454
};
461455

462456
/// Replacing a block argument. This rewrite is not immediately reflected in the
@@ -465,8 +459,10 @@ class BlockTypeConversionRewrite : public BlockRewrite {
465459
class ReplaceBlockArgRewrite : public BlockRewrite {
466460
public:
467461
ReplaceBlockArgRewrite(ConversionPatternRewriterImpl &rewriterImpl,
468-
Block *block, BlockArgument arg)
469-
: BlockRewrite(Kind::ReplaceBlockArg, rewriterImpl, block), arg(arg) {}
462+
Block *block, BlockArgument arg,
463+
const TypeConverter *converter)
464+
: BlockRewrite(Kind::ReplaceBlockArg, rewriterImpl, block), arg(arg),
465+
converter(converter) {}
470466

471467
static bool classof(const IRRewrite *rewrite) {
472468
return rewrite->getKind() == Kind::ReplaceBlockArg;
@@ -478,6 +474,9 @@ class ReplaceBlockArgRewrite : public BlockRewrite {
478474

479475
private:
480476
BlockArgument arg;
477+
478+
/// The current type converter when the block argument was replaced.
479+
const TypeConverter *converter;
481480
};
482481

483482
/// An operation rewrite.
@@ -627,8 +626,6 @@ class ReplaceOperationRewrite : public OperationRewrite {
627626

628627
void cleanup(RewriterBase &rewriter) override;
629628

630-
const TypeConverter *getConverter() const { return converter; }
631-
632629
private:
633630
/// An optional type converter that can be used to materialize conversions
634631
/// between the new and old values if necessary.
@@ -825,6 +822,14 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
825822
ValueRange replacements, Value originalValue,
826823
const TypeConverter *converter);
827824

825+
/// Find a replacement value for the given SSA value in the conversion value
826+
/// mapping. The replacement value must have the same type as the given SSA
827+
/// value. If there is no replacement value with the correct type, find the
828+
/// latest replacement value (regardless of the type) and build a source
829+
/// materialization.
830+
Value findOrBuildReplacementValue(Value value,
831+
const TypeConverter *converter);
832+
828833
//===--------------------------------------------------------------------===//
829834
// Rewriter Notification Hooks
830835
//===--------------------------------------------------------------------===//
@@ -970,7 +975,7 @@ void BlockTypeConversionRewrite::rollback() {
970975
}
971976

972977
void ReplaceBlockArgRewrite::commit(RewriterBase &rewriter) {
973-
Value repl = rewriterImpl.mapping.lookupOrNull(arg, arg.getType());
978+
Value repl = rewriterImpl.findOrBuildReplacementValue(arg, converter);
974979
if (!repl)
975980
return;
976981

@@ -999,7 +1004,7 @@ void ReplaceOperationRewrite::commit(RewriterBase &rewriter) {
9991004
// Compute replacement values.
10001005
SmallVector<Value> replacements =
10011006
llvm::map_to_vector(op->getResults(), [&](OpResult result) {
1002-
return rewriterImpl.mapping.lookupOrNull(result, result.getType());
1007+
return rewriterImpl.findOrBuildReplacementValue(result, converter);
10031008
});
10041009

10051010
// Notify the listener that the operation is about to be replaced.
@@ -1069,8 +1074,10 @@ void UnresolvedMaterializationRewrite::rollback() {
10691074
void ConversionPatternRewriterImpl::applyRewrites() {
10701075
// Commit all rewrites.
10711076
IRRewriter rewriter(context, config.listener);
1072-
for (auto &rewrite : rewrites)
1073-
rewrite->commit(rewriter);
1077+
// Note: New rewrites may be added during the "commit" phase and the
1078+
// `rewrites` vector may reallocate.
1079+
for (size_t i = 0; i < rewrites.size(); ++i)
1080+
rewrites[i]->commit(rewriter);
10741081

10751082
// Clean up all rewrites.
10761083
for (auto &rewrite : rewrites)
@@ -1275,7 +1282,7 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
12751282
/*inputs=*/ValueRange(),
12761283
/*outputType=*/origArgType, /*originalType=*/Type(), converter);
12771284
mapping.map(origArg, repl);
1278-
appendRewrite<ReplaceBlockArgRewrite>(block, origArg);
1285+
appendRewrite<ReplaceBlockArgRewrite>(block, origArg, converter);
12791286
continue;
12801287
}
12811288

@@ -1285,7 +1292,7 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
12851292
"invalid to provide a replacement value when the argument isn't "
12861293
"dropped");
12871294
mapping.map(origArg, repl);
1288-
appendRewrite<ReplaceBlockArgRewrite>(block, origArg);
1295+
appendRewrite<ReplaceBlockArgRewrite>(block, origArg, converter);
12891296
continue;
12901297
}
12911298

@@ -1298,10 +1305,10 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
12981305
insertNTo1Materialization(
12991306
OpBuilder::InsertPoint(newBlock, newBlock->begin()), origArg.getLoc(),
13001307
/*replacements=*/replArgs, /*outputValue=*/origArg, converter);
1301-
appendRewrite<ReplaceBlockArgRewrite>(block, origArg);
1308+
appendRewrite<ReplaceBlockArgRewrite>(block, origArg, converter);
13021309
}
13031310

1304-
appendRewrite<BlockTypeConversionRewrite>(newBlock, block, converter);
1311+
appendRewrite<BlockTypeConversionRewrite>(newBlock, block);
13051312

13061313
// Erase the old block. (It is just unlinked for now and will be erased during
13071314
// cleanup.)
@@ -1371,6 +1378,41 @@ void ConversionPatternRewriterImpl::insertNTo1Materialization(
13711378
}
13721379
}
13731380

1381+
Value ConversionPatternRewriterImpl::findOrBuildReplacementValue(
1382+
Value value, const TypeConverter *converter) {
1383+
// Find a replacement value with the same type.
1384+
Value repl = mapping.lookupOrNull(value, value.getType());
1385+
if (repl)
1386+
return repl;
1387+
1388+
// Check if the value is dead. No replacement value is needed in that case.
1389+
// This is an approximate check that may have false negatives but does not
1390+
// require computing and traversing an inverse mapping. (We may end up
1391+
// building source materializations that are never used and that fold away.)
1392+
if (llvm::all_of(value.getUsers(),
1393+
[&](Operation *op) { return replacedOps.contains(op); }) &&
1394+
!mapping.isMappedTo(value))
1395+
return Value();
1396+
1397+
// No replacement value was found. Get the latest replacement value
1398+
// (regardless of the type) and build a source materialization to the
1399+
// original type.
1400+
repl = mapping.lookupOrNull(value);
1401+
if (!repl) {
1402+
// No replacement value is registered in the mapping. This means that the
1403+
// value is dropped and no longer needed. (If the value were still needed,
1404+
// a source materialization producing a replacement value "out of thin air"
1405+
// would have already been created during `replaceOp` or
1406+
// `applySignatureConversion`.)
1407+
return Value();
1408+
}
1409+
Value castValue = buildUnresolvedMaterialization(
1410+
MaterializationKind::Source, computeInsertPoint(repl), value.getLoc(),
1411+
/*inputs=*/repl, /*outputType=*/value.getType(),
1412+
/*originalType=*/Type(), converter);
1413+
return castValue;
1414+
}
1415+
13741416
//===----------------------------------------------------------------------===//
13751417
// Rewriter Notification Hooks
13761418

@@ -1597,7 +1639,8 @@ void ConversionPatternRewriter::replaceUsesOfBlockArgument(BlockArgument from,
15971639
<< "'(in region of '" << parentOp->getName()
15981640
<< "'(" << from.getOwner()->getParentOp() << ")\n";
15991641
});
1600-
impl->appendRewrite<ReplaceBlockArgRewrite>(from.getOwner(), from);
1642+
impl->appendRewrite<ReplaceBlockArgRewrite>(from.getOwner(), from,
1643+
impl->currentTypeConverter);
16011644
impl->mapping.map(impl->mapping.lookupOrDefault(from), to);
16021645
}
16031646

@@ -2417,10 +2460,6 @@ struct OperationConverter {
24172460
/// Converts an operation with the given rewriter.
24182461
LogicalResult convert(ConversionPatternRewriter &rewriter, Operation *op);
24192462

2420-
/// This method is called after the conversion process to legalize any
2421-
/// remaining artifacts and complete the conversion.
2422-
void finalize(ConversionPatternRewriter &rewriter);
2423-
24242463
/// Dialect conversion configuration.
24252464
ConversionConfig config;
24262465

@@ -2541,11 +2580,6 @@ LogicalResult OperationConverter::convertOperations(ArrayRef<Operation *> ops) {
25412580
if (failed(convert(rewriter, op)))
25422581
return rewriterImpl.undoRewrites(), failure();
25432582

2544-
// Now that all of the operations have been converted, finalize the conversion
2545-
// process to ensure any lingering conversion artifacts are cleaned up and
2546-
// legalized.
2547-
finalize(rewriter);
2548-
25492583
// After a successful conversion, apply rewrites.
25502584
rewriterImpl.applyRewrites();
25512585

@@ -2579,80 +2613,6 @@ LogicalResult OperationConverter::convertOperations(ArrayRef<Operation *> ops) {
25792613
return success();
25802614
}
25812615

2582-
/// Finds a user of the given value, or of any other value that the given value
2583-
/// replaced, that was not replaced in the conversion process.
2584-
static Operation *findLiveUserOfReplaced(
2585-
Value initialValue, ConversionPatternRewriterImpl &rewriterImpl,
2586-
const DenseMap<Value, SmallVector<Value>> &inverseMapping) {
2587-
SmallVector<Value> worklist = {initialValue};
2588-
while (!worklist.empty()) {
2589-
Value value = worklist.pop_back_val();
2590-
2591-
// Walk the users of this value to see if there are any live users that
2592-
// weren't replaced during conversion.
2593-
auto liveUserIt = llvm::find_if_not(value.getUsers(), [&](Operation *user) {
2594-
return rewriterImpl.isOpIgnored(user);
2595-
});
2596-
if (liveUserIt != value.user_end())
2597-
return *liveUserIt;
2598-
auto mapIt = inverseMapping.find(value);
2599-
if (mapIt != inverseMapping.end())
2600-
worklist.append(mapIt->second);
2601-
}
2602-
return nullptr;
2603-
}
2604-
2605-
/// Helper function that returns the replaced values and the type converter if
2606-
/// the given rewrite object is an "operation replacement" or a "block type
2607-
/// conversion" (which corresponds to a "block replacement"). Otherwise, return
2608-
/// an empty ValueRange and a null type converter pointer.
2609-
static std::pair<ValueRange, const TypeConverter *>
2610-
getReplacedValues(IRRewrite *rewrite) {
2611-
if (auto *opRewrite = dyn_cast<ReplaceOperationRewrite>(rewrite))
2612-
return {opRewrite->getOperation()->getResults(), opRewrite->getConverter()};
2613-
if (auto *blockRewrite = dyn_cast<BlockTypeConversionRewrite>(rewrite))
2614-
return {blockRewrite->getOrigBlock()->getArguments(),
2615-
blockRewrite->getConverter()};
2616-
return {};
2617-
}
2618-
2619-
void OperationConverter::finalize(ConversionPatternRewriter &rewriter) {
2620-
ConversionPatternRewriterImpl &rewriterImpl = rewriter.getImpl();
2621-
DenseMap<Value, SmallVector<Value>> inverseMapping =
2622-
rewriterImpl.mapping.getInverse();
2623-
2624-
// Process requested value replacements.
2625-
for (unsigned i = 0, e = rewriterImpl.rewrites.size(); i < e; ++i) {
2626-
ValueRange replacedValues;
2627-
const TypeConverter *converter;
2628-
std::tie(replacedValues, converter) =
2629-
getReplacedValues(rewriterImpl.rewrites[i].get());
2630-
for (Value originalValue : replacedValues) {
2631-
// If the type of this value changed and the value is still live, we need
2632-
// to materialize a conversion.
2633-
if (rewriterImpl.mapping.lookupOrNull(originalValue,
2634-
originalValue.getType()))
2635-
continue;
2636-
Operation *liveUser =
2637-
findLiveUserOfReplaced(originalValue, rewriterImpl, inverseMapping);
2638-
if (!liveUser)
2639-
continue;
2640-
2641-
// Legalize this value replacement.
2642-
Value newValue = rewriterImpl.mapping.lookupOrNull(originalValue);
2643-
assert(newValue && "replacement value not found");
2644-
Value castValue = rewriterImpl.buildUnresolvedMaterialization(
2645-
MaterializationKind::Source, computeInsertPoint(newValue),
2646-
originalValue.getLoc(),
2647-
/*inputs=*/newValue, /*outputType=*/originalValue.getType(),
2648-
/*originalType=*/Type(), converter);
2649-
rewriterImpl.mapping.map(originalValue, castValue);
2650-
inverseMapping[castValue].push_back(originalValue);
2651-
llvm::erase(inverseMapping[newValue], originalValue);
2652-
}
2653-
}
2654-
}
2655-
26562616
//===----------------------------------------------------------------------===//
26572617
// Reconcile Unrealized Casts
26582618
//===----------------------------------------------------------------------===//

0 commit comments

Comments
 (0)