diff --git a/mlir/include/mlir/IR/PatternMatch.h b/mlir/include/mlir/IR/PatternMatch.h index f40405773ee8..51fb620bc4f0 100644 --- a/mlir/include/mlir/IR/PatternMatch.h +++ b/mlir/include/mlir/IR/PatternMatch.h @@ -13,6 +13,7 @@ #include "mlir/IR/BuiltinOps.h" #include "llvm/ADT/FunctionExtras.h" #include "llvm/Support/TypeName.h" +#include #include using llvm::SmallPtrSetImpl; @@ -652,6 +653,9 @@ class RewriterBase : public OpBuilder { /// Find uses of `from` and replace them with `to`. Also notify the listener /// about every in-place op modification (for every use that was replaced). void replaceAllUsesWith(Value from, Value to) { + if (auto *fromOp = from.getDefiningOp()) + getContext()->executeAction( + []() {}, ArrayRef{fromOp}, to); for (OpOperand &operand : llvm::make_early_inc_range(from.getUses())) { Operation *op = operand.getOwner(); modifyOpInPlace(op, [&]() { operand.set(to); }); diff --git a/mlir/include/mlir/IR/PatternMatchAction.h b/mlir/include/mlir/IR/PatternMatchAction.h new file mode 100644 index 000000000000..e49d2b9df520 --- /dev/null +++ b/mlir/include/mlir/IR/PatternMatchAction.h @@ -0,0 +1,20 @@ +#ifndef MLIR_IR_PATTERNMATCHACTION_H +#define MLIR_IR_PATTERNMATCHACTION_H + +#include "mlir/IR/Action.h" + +namespace mlir { +struct ReplaceOpAction : public tracing::ActionImpl { + using Base = tracing::ActionImpl; + ReplaceOpAction(ArrayRef irUnits, ValueRange replacement); + static constexpr StringLiteral tag = "op-replacement"; + void print(raw_ostream &os) const override; + + Operation *getOp() const; + +public: + ValueRange replacement; +}; +} // namespace mlir + +#endif diff --git a/mlir/lib/IR/PatternMatch.cpp b/mlir/lib/IR/PatternMatch.cpp index 286f47ce6913..4c352b35d2b8 100644 --- a/mlir/lib/IR/PatternMatch.cpp +++ b/mlir/lib/IR/PatternMatch.cpp @@ -127,6 +127,35 @@ void RewriterBase::replaceAllOpUsesWith(Operation *from, Operation *to) { replaceAllUsesWith(from->getResults(), to->getResults()); } +ReplaceOpAction::ReplaceOpAction(ArrayRef irUnits, + ValueRange replacement) + : Base(irUnits), replacement(replacement) { + assert(irUnits.size() == 1); + assert(irUnits[0]); + assert(isa(irUnits[0])); +} + +void ReplaceOpAction::print(raw_ostream &os) const { + OpPrintingFlags flags; + flags.elideLargeElementsAttrs(10); + os << "`" << tag << "` replacing operation `"; + getOp()->print(os, flags); + os << "` by "; + bool first = true; + for (auto r : replacement) { + if (!first) + os << ", "; + os << "`"; + r.print(os, flags); + os << "`"; + first = false; + } +} + +Operation *ReplaceOpAction::getOp() const { + return cast(getContextIRUnits()[0]); +} + /// This method replaces the results of the operation with the specified list of /// values. The number of provided values must match the number of results of /// the operation. The replaced op is erased. @@ -265,8 +294,12 @@ void RewriterBase::replaceUsesWithIf(Value from, Value to, bool allReplaced = true; for (OpOperand &operand : llvm::make_early_inc_range(from.getUses())) { bool replace = functor(operand); - if (replace) + if (replace) { + if (auto *fromOp = from.getDefiningOp()) + getContext()->executeAction( + []() {}, ArrayRef{fromOp}, to); modifyOpInPlace(operand.getOwner(), [&]() { operand.set(to); }); + } allReplaced &= replace; } if (allUsesReplaced)