Skip to content

Commit 779412b

Browse files
[mlir][Transforms] Deactivate replaceAllUsesWith in dialect conversion
1 parent 673750f commit 779412b

File tree

3 files changed

+34
-12
lines changed

3 files changed

+34
-12
lines changed

mlir/include/mlir/IR/PatternMatch.h

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -633,13 +633,13 @@ class RewriterBase : public OpBuilder {
633633

634634
/// Find uses of `from` and replace them with `to`. Also notify the listener
635635
/// about every in-place op modification (for every use that was replaced).
636-
void replaceAllUsesWith(Value from, Value to) {
636+
virtual void replaceAllUsesWith(Value from, Value to) {
637637
for (OpOperand &operand : llvm::make_early_inc_range(from.getUses())) {
638638
Operation *op = operand.getOwner();
639639
modifyOpInPlace(op, [&]() { operand.set(to); });
640640
}
641641
}
642-
void replaceAllUsesWith(Block *from, Block *to) {
642+
virtual void replaceAllUsesWith(Block *from, Block *to) {
643643
for (BlockOperand &operand : llvm::make_early_inc_range(from->getUses())) {
644644
Operation *op = operand.getOwner();
645645
modifyOpInPlace(op, [&]() { operand.set(to); });
@@ -665,9 +665,9 @@ class RewriterBase : public OpBuilder {
665665
/// true. Also notify the listener about every in-place op modification (for
666666
/// every use that was replaced). The optional `allUsesReplaced` flag is set
667667
/// to "true" if all uses were replaced.
668-
void replaceUsesWithIf(Value from, Value to,
669-
function_ref<bool(OpOperand &)> functor,
670-
bool *allUsesReplaced = nullptr);
668+
virtual void replaceUsesWithIf(Value from, Value to,
669+
function_ref<bool(OpOperand &)> functor,
670+
bool *allUsesReplaced = nullptr);
671671
void replaceUsesWithIf(ValueRange from, ValueRange to,
672672
function_ref<bool(OpOperand &)> functor,
673673
bool *allUsesReplaced = nullptr);

mlir/include/mlir/Transforms/DialectConversion.h

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -784,6 +784,27 @@ class ConversionPatternRewriter final : public PatternRewriter {
784784
/// function supports both 1:1 and 1:N replacements.
785785
void replaceUsesOfBlockArgument(BlockArgument from, ValueRange to);
786786

787+
/// Replace all the uses of the value `from` with `to`.
788+
/// TODO: Currently not supported in a dialect conversion.
789+
void replaceAllUsesWith(Value from, Value to) override {
790+
llvm::report_fatal_error("replaceAllUsesWith is not supported yet");
791+
}
792+
793+
/// Replace all the uses of the block `from` with `to`.
794+
/// TODO: Currently not supported in a dialect conversion.
795+
void replaceAllUsesWith(Block *from, Block *to) override {
796+
llvm::report_fatal_error("replaceAllUsesWith is not supported yet");
797+
}
798+
799+
/// Replace all the uses of the value `from` with `to` if the `functor`
800+
/// returns "true".
801+
/// TODO: Currently not supported in a dialect conversion.
802+
void replaceUsesWithIf(Value from, Value to,
803+
function_ref<bool(OpOperand &)> functor,
804+
bool *allUsesReplaced = nullptr) override {
805+
llvm::report_fatal_error("replaceUsesWithIf is not supported yet");
806+
}
807+
787808
/// Return the converted value of 'key' with a type defined by the type
788809
/// converter of the currently executing pattern. Return nullptr in the case
789810
/// of failure, the remapped value otherwise.

mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
#include "mlir/Dialect/SCF/IR/SCF.h"
2323
#include "mlir/IR/SymbolTable.h"
2424
#include "mlir/Pass/Pass.h"
25-
#include "mlir/Transforms/DialectConversion.h"
25+
#include "mlir/Transforms/WalkPatternRewriteDriver.h"
2626

2727
namespace mlir {
2828
#define GEN_PASS_DEF_CONVERTSCFTOOPENMPPASS
@@ -538,15 +538,16 @@ struct ParallelOpLowering : public OpRewritePattern<scf::ParallelOp> {
538538

539539
/// Applies the conversion patterns in the given function.
540540
static LogicalResult applyPatterns(ModuleOp module, unsigned numThreads) {
541-
ConversionTarget target(*module.getContext());
542-
target.addIllegalOp<scf::ReduceOp, scf::ReduceReturnOp, scf::ParallelOp>();
543-
target.addLegalDialect<omp::OpenMPDialect, LLVM::LLVMDialect,
544-
memref::MemRefDialect>();
545-
546541
RewritePatternSet patterns(module.getContext());
547542
patterns.add<ParallelOpLowering>(module.getContext(), numThreads);
548543
FrozenRewritePatternSet frozen(std::move(patterns));
549-
return applyPartialConversion(module, target, frozen);
544+
walkAndApplyPatterns(module, frozen);
545+
auto status = module.walk([](Operation *op) {
546+
if (isa<scf::ReduceOp, scf::ReduceReturnOp, scf::ParallelOp>(op))
547+
return WalkResult::interrupt();
548+
return WalkResult::advance();
549+
});
550+
return failure(status.wasInterrupted());
550551
}
551552

552553
/// A pass converting SCF operations to OpenMP operations.

0 commit comments

Comments
 (0)