Skip to content

Commit 96a3a58

Browse files
[mlir][Transforms] Simplify ConversionPatternRewriter::replaceOp implementation (#158075)
Move the logic for building "out-of-thin-air" source materializations during op replacements from `replaceOp` to `findOrBuildReplacementValue`. That function already builds source materializations and can handle the case where an op result is dropped. This commit is in preparation of turning `replaceOp` into a non-virtual function. (It is sufficient for `replaceAllUsesWith` and `eraseOp` to be virtual.)
1 parent ec5460b commit 96a3a58

File tree

1 file changed

+24
-35
lines changed

1 file changed

+24
-35
lines changed

mlir/lib/Transforms/Utils/DialectConversion.cpp

Lines changed: 24 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -1618,6 +1618,8 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
16181618
if (!inputMap) {
16191619
// This block argument was dropped and no replacement value was provided.
16201620
// Materialize a replacement value "out of thin air".
1621+
// Note: Materialization must be built here because we cannot find a
1622+
// valid insertion point in the new block. (Will point to the old block.)
16211623
Value mat =
16221624
buildUnresolvedMaterialization(
16231625
MaterializationKind::Source,
@@ -1725,29 +1727,29 @@ Value ConversionPatternRewriterImpl::findOrBuildReplacementValue(
17251727
// (regardless of the type) and build a source materialization to the
17261728
// original type.
17271729
repl = lookupOrNull(value);
1730+
1731+
// Compute the insertion point of the materialization.
1732+
OpBuilder::InsertPoint ip;
17281733
if (repl.empty()) {
1729-
// No replacement value is registered in the mapping. This means that the
1730-
// value is dropped and no longer needed. (If the value were still needed,
1731-
// a source materialization producing a replacement value "out of thin air"
1732-
// would have already been created during `replaceOp` or
1733-
// `applySignatureConversion`.)
1734-
return Value();
1734+
// The source materialization has no inputs. Insert it right before the
1735+
// value that it is replacing.
1736+
ip = computeInsertPoint(value);
1737+
} else {
1738+
// Compute the "earliest" insertion point at which all values in `repl` are
1739+
// defined. It is important to emit the materialization at that location
1740+
// because the same materialization may be reused in a different context.
1741+
// (That's because materializations are cached in the conversion value
1742+
// mapping.) The insertion point of the materialization must be valid for
1743+
// all future users that may be created later in the conversion process.
1744+
ip = computeInsertPoint(repl);
17351745
}
1736-
1737-
// Note: `computeInsertPoint` computes the "earliest" insertion point at
1738-
// which all values in `repl` are defined. It is important to emit the
1739-
// materialization at that location because the same materialization may be
1740-
// reused in a different context. (That's because materializations are cached
1741-
// in the conversion value mapping.) The insertion point of the
1742-
// materialization must be valid for all future users that may be created
1743-
// later in the conversion process.
1744-
Value castValue =
1745-
buildUnresolvedMaterialization(MaterializationKind::Source,
1746-
computeInsertPoint(repl), value.getLoc(),
1747-
/*valuesToMap=*/repl, /*inputs=*/repl,
1748-
/*outputTypes=*/value.getType(),
1749-
/*originalType=*/Type(), converter)
1750-
.front();
1746+
Value castValue = buildUnresolvedMaterialization(
1747+
MaterializationKind::Source, ip, value.getLoc(),
1748+
/*valuesToMap=*/repl, /*inputs=*/repl,
1749+
/*outputTypes=*/value.getType(),
1750+
/*originalType=*/Type(), converter,
1751+
/*isPureTypeConversion=*/!repl.empty())
1752+
.front();
17511753
return castValue;
17521754
}
17531755

@@ -1897,21 +1899,8 @@ void ConversionPatternRewriterImpl::replaceOp(
18971899
}
18981900

18991901
// Create mappings for each of the new result values.
1900-
for (auto [repl, result] : llvm::zip_equal(newValues, op->getResults())) {
1901-
if (repl.empty()) {
1902-
// This result was dropped and no replacement value was provided.
1903-
// Materialize a replacement value "out of thin air".
1904-
buildUnresolvedMaterialization(
1905-
MaterializationKind::Source, computeInsertPoint(result),
1906-
result.getLoc(), /*valuesToMap=*/{result}, /*inputs=*/ValueRange(),
1907-
/*outputTypes=*/result.getType(), /*originalType=*/Type(),
1908-
currentTypeConverter, /*isPureTypeConversion=*/false);
1909-
continue;
1910-
}
1911-
1912-
// Remap result to replacement value.
1902+
for (auto [repl, result] : llvm::zip_equal(newValues, op->getResults()))
19131903
mapping.map(static_cast<Value>(result), std::move(repl));
1914-
}
19151904

19161905
appendRewrite<ReplaceOperationRewrite>(op, currentTypeConverter);
19171906
// Mark this operation and all nested ops as replaced.

0 commit comments

Comments
 (0)