Skip to content

Commit 504b507

Browse files
[mlir][Transforms] Dialect conversion: Add support for replaceUsesWithIf (#169606)
This commit adds support for `replaceUsesWithIf` (and variants such as `replaceAllUsesExcept`) to the `ConversionPatternRewriter`. This API is supported only in no-rollback mode. An assertion is triggered in rollback mode. (This missing assertion has been confusing for users because it seemed that the API supported, while it was actually not working properly.) This commit brings us a bit closer towards removing [this](https://github.com/llvm/llvm-project/blob/76ec25f729fcc7ae576caf21293cc393e68e7cf7/mlir/lib/Transforms/Utils/DialectConversion.cpp#L1214) workaround. Additional changes are needed to support this API in rollback mode. In particular, no entries should be added to the `ConversionValueMapping` for conditional replacements. It's unclear at this point if this API can be supported in rollback mode, so this is deferred to later. This commit turns `replaceUsesWithIf` into a virtual function, so that the `ConversionPatternRewriter` can override it. All other API functions for conditional value replacements call that function. Note for LLVM integration: If you are seeing failed assertions due to this change, you are using unsupported API in your dialect conversion. You have 3 options: (1) Migrate to the no-rollback driver. (2) Rewrite your patterns without the unsupported API. (3) Last resort: bypass the rewriter and call `replaceUsesWithIf` etc. directly on the `Value` object.
1 parent bd643bc commit 504b507

File tree

5 files changed

+107
-29
lines changed

5 files changed

+107
-29
lines changed

mlir/include/mlir/IR/PatternMatch.h

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -675,9 +675,9 @@ class RewriterBase : public OpBuilder {
675675
/// true. Also notify the listener about every in-place op modification (for
676676
/// every use that was replaced). The optional `allUsesReplaced` flag is set
677677
/// to "true" if all uses were replaced.
678-
void replaceUsesWithIf(Value from, Value to,
679-
function_ref<bool(OpOperand &)> functor,
680-
bool *allUsesReplaced = nullptr);
678+
virtual void replaceUsesWithIf(Value from, Value to,
679+
function_ref<bool(OpOperand &)> functor,
680+
bool *allUsesReplaced = nullptr);
681681
void replaceUsesWithIf(ValueRange from, ValueRange to,
682682
function_ref<bool(OpOperand &)> functor,
683683
bool *allUsesReplaced = nullptr);

mlir/include/mlir/Transforms/DialectConversion.h

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -903,6 +903,27 @@ class ConversionPatternRewriter final : public PatternRewriter {
903903
replaceAllUsesWith(from, ValueRange{to});
904904
}
905905

906+
/// Replace the uses of `from` with `to` for which the `functor` returns
907+
/// "true". The conversion driver will try to reconcile all type mismatches
908+
/// that still exist at the end of the conversion with materializations.
909+
/// This function supports both 1:1 and 1:N replacements.
910+
///
911+
/// Note: The functor is also applied to builtin.unrealized_conversion_cast
912+
/// ops that may have been inserted by the conversion driver. Some uses may
913+
/// have been wrapped in unrealized_conversion_cast ops due to type changes.
914+
///
915+
/// Note: This function is not supported in rollback mode. Calling it in
916+
/// rollback mode will trigger an assertion. Furthermore, the
917+
/// `allUsesReplaced` flag is not supported yet.
918+
void replaceUsesWithIf(Value from, Value to,
919+
function_ref<bool(OpOperand &)> functor,
920+
bool *allUsesReplaced = nullptr) override {
921+
replaceUsesWithIf(from, ValueRange{to}, functor, allUsesReplaced);
922+
}
923+
void replaceUsesWithIf(Value from, ValueRange to,
924+
function_ref<bool(OpOperand &)> functor,
925+
bool *allUsesReplaced = nullptr);
926+
906927
/// Return the converted value of 'key' with a type defined by the type
907928
/// converter of the currently executing pattern. Return nullptr in the case
908929
/// of failure, the remapped value otherwise.

mlir/lib/Transforms/Utils/DialectConversion.cpp

Lines changed: 53 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -976,9 +976,12 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
976976
void replaceOp(Operation *op, SmallVector<SmallVector<Value>> &&newValues);
977977

978978
/// Replace the uses of the given value with the given values. The specified
979-
/// converter is used to build materializations (if necessary).
980-
void replaceAllUsesWith(Value from, ValueRange to,
981-
const TypeConverter *converter);
979+
/// converter is used to build materializations (if necessary). If `functor`
980+
/// is specified, only the uses that the functor returns "true" for are
981+
/// replaced.
982+
void replaceValueUses(Value from, ValueRange to,
983+
const TypeConverter *converter,
984+
function_ref<bool(OpOperand &)> functor = nullptr);
982985

983986
/// Erase the given block and its contents.
984987
void eraseBlock(Block *block);
@@ -1203,11 +1206,16 @@ void BlockTypeConversionRewrite::rollback() {
12031206
}
12041207

12051208
/// Replace all uses of `from` with `repl`.
1206-
static void performReplaceValue(RewriterBase &rewriter, Value from,
1207-
Value repl) {
1209+
static void
1210+
performReplaceValue(RewriterBase &rewriter, Value from, Value repl,
1211+
function_ref<bool(OpOperand &)> functor = nullptr) {
12081212
if (isa<BlockArgument>(repl)) {
12091213
// `repl` is a block argument. Directly replace all uses.
1210-
rewriter.replaceAllUsesWith(from, repl);
1214+
if (functor) {
1215+
rewriter.replaceUsesWithIf(from, repl, functor);
1216+
} else {
1217+
rewriter.replaceAllUsesWith(from, repl);
1218+
}
12111219
return;
12121220
}
12131221

@@ -1238,7 +1246,11 @@ static void performReplaceValue(RewriterBase &rewriter, Value from,
12381246
Block *replBlock = replOp->getBlock();
12391247
rewriter.replaceUsesWithIf(from, repl, [&](OpOperand &operand) {
12401248
Operation *user = operand.getOwner();
1241-
return user->getBlock() != replBlock || replOp->isBeforeInBlock(user);
1249+
bool result =
1250+
user->getBlock() != replBlock || replOp->isBeforeInBlock(user);
1251+
if (result && functor)
1252+
result &= functor(operand);
1253+
return result;
12421254
});
12431255
}
12441256

@@ -1646,7 +1658,7 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
16461658
/*outputTypes=*/origArgType, /*originalType=*/Type(), converter,
16471659
/*isPureTypeConversion=*/false)
16481660
.front();
1649-
replaceAllUsesWith(origArg, mat, converter);
1661+
replaceValueUses(origArg, mat, converter);
16501662
continue;
16511663
}
16521664

@@ -1655,14 +1667,14 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
16551667
assert(inputMap->size == 0 &&
16561668
"invalid to provide a replacement value when the argument isn't "
16571669
"dropped");
1658-
replaceAllUsesWith(origArg, inputMap->replacementValues, converter);
1670+
replaceValueUses(origArg, inputMap->replacementValues, converter);
16591671
continue;
16601672
}
16611673

16621674
// This is a 1->1+ mapping.
16631675
auto replArgs =
16641676
newBlock->getArguments().slice(inputMap->inputNo, inputMap->size);
1665-
replaceAllUsesWith(origArg, replArgs, converter);
1677+
replaceValueUses(origArg, replArgs, converter);
16661678
}
16671679

16681680
if (config.allowPatternRollback)
@@ -1962,8 +1974,24 @@ void ConversionPatternRewriterImpl::replaceOp(
19621974
op->walk([&](Operation *op) { replacedOps.insert(op); });
19631975
}
19641976

1965-
void ConversionPatternRewriterImpl::replaceAllUsesWith(
1966-
Value from, ValueRange to, const TypeConverter *converter) {
1977+
void ConversionPatternRewriterImpl::replaceValueUses(
1978+
Value from, ValueRange to, const TypeConverter *converter,
1979+
function_ref<bool(OpOperand &)> functor) {
1980+
LLVM_DEBUG({
1981+
logger.startLine() << "** Replace Value : '" << from << "'";
1982+
if (auto blockArg = dyn_cast<BlockArgument>(from)) {
1983+
if (Operation *parentOp = blockArg.getOwner()->getParentOp()) {
1984+
logger.getOStream() << " (in region of '" << parentOp->getName()
1985+
<< "' (" << parentOp << ")";
1986+
} else {
1987+
logger.getOStream() << " (unlinked block)";
1988+
}
1989+
}
1990+
if (functor) {
1991+
logger.getOStream() << ", conditional replacement";
1992+
}
1993+
});
1994+
19671995
if (!config.allowPatternRollback) {
19681996
SmallVector<Value> toConv = llvm::to_vector(to);
19691997
SmallVector<Value> repls =
@@ -1973,7 +2001,7 @@ void ConversionPatternRewriterImpl::replaceAllUsesWith(
19732001
if (!repl)
19742002
return;
19752003

1976-
performReplaceValue(r, from, repl);
2004+
performReplaceValue(r, from, repl, functor);
19772005
return;
19782006
}
19792007

@@ -1992,6 +2020,9 @@ void ConversionPatternRewriterImpl::replaceAllUsesWith(
19922020
replacedValues.insert(from);
19932021
#endif // NDEBUG
19942022

2023+
if (functor)
2024+
llvm::report_fatal_error(
2025+
"conditional value replacement is not supported in rollback mode");
19952026
mapping.map(from, to);
19962027
appendRewrite<ReplaceValueRewrite>(from, converter);
19972028
}
@@ -2190,18 +2221,15 @@ FailureOr<Block *> ConversionPatternRewriter::convertRegionTypes(
21902221
}
21912222

21922223
void ConversionPatternRewriter::replaceAllUsesWith(Value from, ValueRange to) {
2193-
LLVM_DEBUG({
2194-
impl->logger.startLine() << "** Replace Value : '" << from << "'";
2195-
if (auto blockArg = dyn_cast<BlockArgument>(from)) {
2196-
if (Operation *parentOp = blockArg.getOwner()->getParentOp()) {
2197-
impl->logger.getOStream() << " (in region of '" << parentOp->getName()
2198-
<< "' (" << parentOp << ")\n";
2199-
} else {
2200-
impl->logger.getOStream() << " (unlinked block)\n";
2201-
}
2202-
}
2203-
});
2204-
impl->replaceAllUsesWith(from, to, impl->currentTypeConverter);
2224+
impl->replaceValueUses(from, to, impl->currentTypeConverter);
2225+
}
2226+
2227+
void ConversionPatternRewriter::replaceUsesWithIf(
2228+
Value from, ValueRange to, function_ref<bool(OpOperand &)> functor,
2229+
bool *allUsesReplaced) {
2230+
assert(!allUsesReplaced &&
2231+
"allUsesReplaced is not supported in a dialect conversion");
2232+
impl->replaceValueUses(from, to, impl->currentTypeConverter, functor);
22052233
}
22062234

22072235
Value ConversionPatternRewriter::getRemappedValue(Value key) {
Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
// RUN: mlir-opt -allow-unregistered-dialect -split-input-file -test-legalize-patterns="allow-pattern-rollback=0" -verify-diagnostics %s | FileCheck %s
2+
3+
// CHECK-LABEL: @conditional_replacement(
4+
// CHECK-SAME: %[[arg0:.*]]: i43)
5+
// CHECK: %[[cast1:.*]] = "test.cast"(%[[arg0]]) : (i43) -> i42
6+
// CHECK: %[[legal:.*]] = "test.legal_op"() : () -> i42
7+
// CHECK: %[[cast2:.*]] = "test.cast"(%[[legal]], %[[legal]]) : (i42, i42) -> i42
8+
// Uses were replaced for dummy_user_1.
9+
// CHECK: "test.dummy_user_1"(%[[cast2]]) {replace_uses} : (i42) -> ()
10+
// Uses were also replaced for dummy_user_2, but not by value_replace. The uses
11+
// were replaced due to the block signature conversion.
12+
// CHECK: "test.dummy_user_2"(%[[cast1]]) : (i42) -> ()
13+
// CHECK: "test.value_replace"(%[[cast1]], %[[legal]]) {conditional, is_legal} : (i42, i42) -> ()
14+
func.func @conditional_replacement(%arg0: i42) {
15+
%repl = "test.legal_op"() : () -> (i42)
16+
// expected-remark @+1 {{is not legalizable}}
17+
"test.dummy_user_1"(%arg0) {replace_uses} : (i42) -> ()
18+
// expected-remark @+1 {{is not legalizable}}
19+
"test.dummy_user_2"(%arg0) {} : (i42) -> ()
20+
// Perform a conditional 1:N replacement.
21+
"test.value_replace"(%arg0, %repl) {conditional} : (i42, i42) -> ()
22+
"test.return"() : () -> ()
23+
}

mlir/test/lib/Dialect/Test/TestPatterns.cpp

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -977,7 +977,13 @@ struct TestValueReplace : public ConversionPattern {
977977
// Replace the first operand with 2x the second operand.
978978
Value from = op->getOperand(0);
979979
Value repl = op->getOperand(1);
980-
rewriter.replaceAllUsesWith(from, {repl, repl});
980+
if (op->hasAttr("conditional")) {
981+
rewriter.replaceUsesWithIf(from, {repl, repl}, [=](OpOperand &use) {
982+
return use.getOwner()->hasAttr("replace_uses");
983+
});
984+
} else {
985+
rewriter.replaceAllUsesWith(from, {repl, repl});
986+
}
981987
rewriter.modifyOpInPlace(op, [&] {
982988
// If the "trigger_rollback" attribute is set, keep the op illegal, so
983989
// that a rollback is triggered.

0 commit comments

Comments
 (0)