Skip to content

Commit a84528d

Browse files
committed
MLIR: Add ReplaceOpAction into DialectConversion and PatternMatch
1 parent 0faeea4 commit a84528d

File tree

3 files changed

+57
-1
lines changed

3 files changed

+57
-1
lines changed

mlir/include/mlir/IR/PatternMatch.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
#include "mlir/IR/BuiltinOps.h"
1414
#include "llvm/ADT/FunctionExtras.h"
1515
#include "llvm/Support/TypeName.h"
16+
#include <mlir/IR/PatternMatchAction.h>
1617
#include <optional>
1718

1819
using llvm::SmallPtrSetImpl;
@@ -652,6 +653,8 @@ class RewriterBase : public OpBuilder {
652653
/// Find uses of `from` and replace them with `to`. Also notify the listener
653654
/// about every in-place op modification (for every use that was replaced).
654655
void replaceAllUsesWith(Value from, Value to) {
656+
if(auto* fromOp = from.getDefiningOp())
657+
getContext()->executeAction<ReplaceOpAction>([]() {}, ArrayRef<IRUnit>{fromOp}, to);
655658
for (OpOperand &operand : llvm::make_early_inc_range(from.getUses())) {
656659
Operation *op = operand.getOwner();
657660
modifyOpInPlace(op, [&]() { operand.set(to); });
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
#ifndef MLIR_IR_PATTERNMATCHACTION_H
2+
#define MLIR_IR_PATTERNMATCHACTION_H
3+
4+
#include "mlir/IR/Action.h"
5+
6+
namespace mlir {
7+
struct ReplaceOpAction : public tracing::ActionImpl<ReplaceOpAction> {
8+
using Base = tracing::ActionImpl<ReplaceOpAction>;
9+
ReplaceOpAction(ArrayRef<IRUnit> irUnits, ValueRange replacement);
10+
static constexpr StringLiteral tag = "op-replacement";
11+
void print(raw_ostream &os) const override;
12+
13+
Operation *getOp() const;
14+
15+
public:
16+
ValueRange replacement;
17+
};
18+
}
19+
20+
#endif

mlir/lib/IR/PatternMatch.cpp

Lines changed: 34 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -127,6 +127,35 @@ void RewriterBase::replaceAllOpUsesWith(Operation *from, Operation *to) {
127127
replaceAllUsesWith(from->getResults(), to->getResults());
128128
}
129129

130+
ReplaceOpAction::ReplaceOpAction(ArrayRef<IRUnit> irUnits,
131+
ValueRange replacement)
132+
: Base(irUnits), replacement(replacement) {
133+
assert(irUnits.size() == 1);
134+
assert(irUnits[0]);
135+
assert(isa<Operation *>(irUnits[0]));
136+
}
137+
138+
void ReplaceOpAction::print(raw_ostream &os) const {
139+
OpPrintingFlags flags;
140+
flags.elideLargeElementsAttrs(10);
141+
os << "`" << tag << "` replacing operation `";
142+
getOp()->print(os, flags);
143+
os << "` by ";
144+
bool first = true;
145+
for (auto r : replacement) {
146+
if (!first)
147+
os << ", ";
148+
os << "`";
149+
r.print(os, flags);
150+
os << "`";
151+
first = false;
152+
}
153+
}
154+
155+
Operation *ReplaceOpAction::getOp() const {
156+
return cast<Operation *>(getContextIRUnits()[0]);
157+
}
158+
130159
/// This method replaces the results of the operation with the specified list of
131160
/// values. The number of provided values must match the number of results of
132161
/// the operation. The replaced op is erased.
@@ -265,8 +294,12 @@ void RewriterBase::replaceUsesWithIf(Value from, Value to,
265294
bool allReplaced = true;
266295
for (OpOperand &operand : llvm::make_early_inc_range(from.getUses())) {
267296
bool replace = functor(operand);
268-
if (replace)
297+
if (replace) {
298+
if (auto *fromOp = from.getDefiningOp())
299+
getContext()->executeAction<ReplaceOpAction>(
300+
[]() {}, ArrayRef<IRUnit>{fromOp}, to);
269301
modifyOpInPlace(operand.getOwner(), [&]() { operand.set(to); });
302+
}
270303
allReplaced &= replace;
271304
}
272305
if (allUsesReplaced)

0 commit comments

Comments
 (0)