Skip to content

Commit bf57b8d

Browse files
[mlir][Transforms] Detect mapping overwrites during block signature conversion
Add extra assertions to make sure that a value in the conversion value mapping is not overwritten during `applySignatureConversion`.
1 parent 1102603 commit bf57b8d

File tree

1 file changed

+8
-3
lines changed

1 file changed

+8
-3
lines changed

mlir/lib/Transforms/Utils/DialectConversion.cpp

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -176,6 +176,8 @@ struct ConversionValueMapping {
176176
template <typename OldVal, typename NewVal>
177177
std::enable_if_t<IsValueVector<OldVal>::value && IsValueVector<NewVal>::value>
178178
map(OldVal &&oldVal, NewVal &&newVal) {
179+
assert(!mapping.contains(oldVal) &&
180+
"attempting to overwrite existing mapping");
179181
LLVM_DEBUG({
180182
ValueVector next(newVal);
181183
while (true) {
@@ -1412,6 +1414,7 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
14121414
for (unsigned i = 0; i != origArgCount; ++i) {
14131415
BlockArgument origArg = block->getArgument(i);
14141416
Type origArgType = origArg.getType();
1417+
ValueVector currentMapping = mapping.lookupOrDefault(origArg);
14151418

14161419
std::optional<TypeConverter::SignatureConversion::InputMapping> inputMap =
14171420
signatureConversion.getInputMapping(i);
@@ -1421,7 +1424,7 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
14211424
buildUnresolvedMaterialization(
14221425
MaterializationKind::Source,
14231426
OpBuilder::InsertPoint(newBlock, newBlock->begin()), origArg.getLoc(),
1424-
/*valuesToMap=*/{origArg}, /*inputs=*/ValueRange(),
1427+
/*valuesToMap=*/currentMapping, /*inputs=*/ValueRange(),
14251428
/*outputType=*/origArgType, /*originalType=*/Type(), converter);
14261429
appendRewrite<ReplaceBlockArgRewrite>(block, origArg, converter);
14271430
continue;
@@ -1432,7 +1435,7 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
14321435
assert(inputMap->size == 0 &&
14331436
"invalid to provide a replacement value when the argument isn't "
14341437
"dropped");
1435-
mapping.map(origArg, repl);
1438+
mapping.map(currentMapping, repl);
14361439
appendRewrite<ReplaceBlockArgRewrite>(block, origArg, converter);
14371440
continue;
14381441
}
@@ -1441,7 +1444,7 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
14411444
auto replArgs =
14421445
newBlock->getArguments().slice(inputMap->inputNo, inputMap->size);
14431446
ValueVector replArgVals = llvm::to_vector_of<Value, 1>(replArgs);
1444-
mapping.map(origArg, std::move(replArgVals));
1447+
mapping.map(currentMapping, std::move(replArgVals));
14451448
appendRewrite<ReplaceBlockArgRewrite>(block, origArg, converter);
14461449
}
14471450

@@ -1757,6 +1760,8 @@ void ConversionPatternRewriter::replaceUsesOfBlockArgument(BlockArgument from,
17571760
<< "'(in region of '" << parentOp->getName()
17581761
<< "'(" << from.getOwner()->getParentOp() << ")\n";
17591762
});
1763+
llvm::errs() << "replaceUsesOfBlockArgument: " << from.getOwner() << "\n";
1764+
17601765
impl->appendRewrite<ReplaceBlockArgRewrite>(from.getOwner(), from,
17611766
impl->currentTypeConverter);
17621767
impl->mapping.map(impl->mapping.lookupOrDefault(from), to);

0 commit comments

Comments
 (0)