Skip to content

Commit 2929a29

Browse files
[mlir][Transforms] Add support for ConversionPatternRewriter::replaceAllUsesWith (#155244)
This commit generalizes `replaceUsesOfBlockArgument` to `replaceAllUsesWith`. In rollback mode, the same restrictions keep applying: a value cannot be replaced multiple times and a call to `replaceAllUsesWith` will replace all current and future uses of the `from` value. `replaceAllUsesWith` is now fully supported and its behavior is consistent with the remaining dialect conversion API. Before this commit, `replaceAllUsesWith` was immediately reflected in the IR when running in rollback mode. After this commit, `replaceAllUsesWith` changes are materialized in a delayed fashion, at the end of the dialect conversion. This is consistent with the `replaceUsesOfBlockArgument` and `replaceOp` APIs. `replaceAllUsesExcept` etc. are still not supported and will be deactivated on the `ConversionPatternRewriter` (when running in rollback mode) in a follow-up commit. Note for LLVM integration: Replace `replaceUsesOfBlockArgument` with `replaceAllUsesWith`. If you are seeing failures, you may have patterns that use `replaceAllUsesWith` incorrectly (e.g., being called multiple times on the same value) or bypass the rewriter API entirely. E.g., such failures were mitigated in Flang by switching to the walk-patterns driver (#156171). You can temporarily reactivate the old behavior by calling `RewriterBase::replaceAllUsesWith`. However, note that that behavior is faulty in a dialect conversion. E.g., the base `RewriterBase::replaceAllUsesWith` implementation does not see uses of the `from` value that have not materialized yet and will, therefore, not replace them.
1 parent 8dda18f commit 2929a29

File tree

8 files changed

+194
-91
lines changed

8 files changed

+194
-91
lines changed

flang/lib/Optimizer/OpenMP/DoConcurrentConversion.cpp

Lines changed: 21 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -444,6 +444,18 @@ class DoConcurrentConversion
444444
mlir::SymbolTable &moduleSymbolTable;
445445
};
446446

447+
/// A listener that forwards notifyOperationErased to the given callback.
448+
struct CallbackListener : public mlir::RewriterBase::Listener {
449+
CallbackListener(std::function<void(mlir::Operation *op)> onOperationErased)
450+
: onOperationErased(onOperationErased) {}
451+
452+
void notifyOperationErased(mlir::Operation *op) override {
453+
onOperationErased(op);
454+
}
455+
456+
std::function<void(mlir::Operation *op)> onOperationErased;
457+
};
458+
447459
class DoConcurrentConversionPass
448460
: public flangomp::impl::DoConcurrentConversionPassBase<
449461
DoConcurrentConversionPass> {
@@ -468,6 +480,10 @@ class DoConcurrentConversionPass
468480
}
469481

470482
llvm::DenseSet<fir::DoConcurrentOp> concurrentLoopsToSkip;
483+
CallbackListener callbackListener([&](mlir::Operation *op) {
484+
if (auto loop = mlir::dyn_cast<fir::DoConcurrentOp>(op))
485+
concurrentLoopsToSkip.erase(loop);
486+
});
471487
mlir::RewritePatternSet patterns(context);
472488
patterns.insert<DoConcurrentConversion>(
473489
context, mapTo == flangomp::DoConcurrentMappingKind::DCMK_Device,
@@ -480,8 +496,11 @@ class DoConcurrentConversionPass
480496
target.markUnknownOpDynamicallyLegal(
481497
[](mlir::Operation *) { return true; });
482498

483-
if (mlir::failed(
484-
mlir::applyFullConversion(module, target, std::move(patterns)))) {
499+
mlir::ConversionConfig config;
500+
config.allowPatternRollback = false;
501+
config.listener = &callbackListener;
502+
if (mlir::failed(mlir::applyFullConversion(module, target,
503+
std::move(patterns), config))) {
485504
signalPassFailure();
486505
}
487506
}

mlir/include/mlir/IR/PatternMatch.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -633,7 +633,7 @@ class RewriterBase : public OpBuilder {
633633

634634
/// Find uses of `from` and replace them with `to`. Also notify the listener
635635
/// about every in-place op modification (for every use that was replaced).
636-
void replaceAllUsesWith(Value from, Value to) {
636+
virtual void replaceAllUsesWith(Value from, Value to) {
637637
for (OpOperand &operand : llvm::make_early_inc_range(from.getUses())) {
638638
Operation *op = operand.getOwner();
639639
modifyOpInPlace(op, [&]() { operand.set(to); });

mlir/include/mlir/Transforms/DialectConversion.h

Lines changed: 22 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -854,15 +854,29 @@ class ConversionPatternRewriter final : public PatternRewriter {
854854
Region *region, const TypeConverter &converter,
855855
TypeConverter::SignatureConversion *entryConversion = nullptr);
856856

857-
/// Replace all the uses of the block argument `from` with `to`. This
858-
/// function supports both 1:1 and 1:N replacements.
857+
/// Replace all the uses of `from` with `to`. The type of `from` and `to` is
858+
/// allowed to differ. The conversion driver will try to reconcile all type
859+
/// mismatches that still exist at the end of the conversion with
860+
/// materializations. This function supports both 1:1 and 1:N replacements.
859861
///
860-
/// Note: If `allowPatternRollback` is set to "true", this function replaces
861-
/// all current and future uses of the block argument. This same block
862-
/// block argument must not be replaced multiple times. Uses are not replaced
863-
/// immediately but in a delayed fashion. Patterns may still see the original
864-
/// uses when inspecting IR.
865-
void replaceUsesOfBlockArgument(BlockArgument from, ValueRange to);
862+
/// Note: If `allowPatternRollback` is set to "true", this function behaves
863+
/// slightly different:
864+
///
865+
/// 1. All current and future uses of `from` are replaced. The same value must
866+
/// not be replaced multiple times. That's an API violation.
867+
/// 2. Uses are not replaced immediately but in a delayed fashion. Patterns
868+
/// may still see the original uses when inspecting IR.
869+
/// 3. Uses within the same block that appear before the defining operation
870+
/// of the replacement value are not replaced. This allows users to
871+
/// perform certain replaceAllUsesExcept-style replacements, even though
872+
/// such API is not directly supported.
873+
///
874+
/// Note: In an attempt to align the ConversionPatternRewriter and
875+
/// RewriterBase APIs, (3) may be removed in the future.
876+
void replaceAllUsesWith(Value from, ValueRange to);
877+
void replaceAllUsesWith(Value from, Value to) override {
878+
replaceAllUsesWith(from, ValueRange{to});
879+
}
866880

867881
/// Return the converted value of 'key' with a type defined by the type
868882
/// converter of the currently executing pattern. Return nullptr in the case

mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -284,7 +284,7 @@ static void restoreByValRefArgumentType(
284284
cast<TypeAttr>(byValRefAttr->getValue()).getValue());
285285

286286
Value valueArg = LLVM::LoadOp::create(rewriter, arg.getLoc(), resTy, arg);
287-
rewriter.replaceUsesOfBlockArgument(arg, valueArg);
287+
rewriter.replaceAllUsesWith(arg, valueArg);
288288
}
289289
}
290290

0 commit comments

Comments
 (0)