@@ -127,6 +127,35 @@ void RewriterBase::replaceAllOpUsesWith(Operation *from, Operation *to) {
127127 replaceAllUsesWith (from->getResults (), to->getResults ());
128128}
129129
130+ ReplaceOpAction::ReplaceOpAction (ArrayRef<IRUnit> irUnits,
131+ ValueRange replacement)
132+ : Base(irUnits), replacement(replacement) {
133+ assert (irUnits.size () == 1 );
134+ assert (irUnits[0 ]);
135+ assert (isa<Operation *>(irUnits[0 ]));
136+ }
137+
138+ void ReplaceOpAction::print (raw_ostream &os) const {
139+ OpPrintingFlags flags;
140+ flags.elideLargeElementsAttrs (10 );
141+ os << " `" << tag << " ` replacing operation `" ;
142+ getOp ()->print (os, flags);
143+ os << " ` by " ;
144+ bool first = true ;
145+ for (auto r : replacement) {
146+ if (!first)
147+ os << " , " ;
148+ os << " `" ;
149+ r.print (os, flags);
150+ os << " `" ;
151+ first = false ;
152+ }
153+ }
154+
155+ Operation *ReplaceOpAction::getOp () const {
156+ return cast<Operation *>(getContextIRUnits ()[0 ]);
157+ }
158+
130159// / This method replaces the results of the operation with the specified list of
131160// / values. The number of provided values must match the number of results of
132161// / the operation. The replaced op is erased.
@@ -265,8 +294,12 @@ void RewriterBase::replaceUsesWithIf(Value from, Value to,
265294 bool allReplaced = true ;
266295 for (OpOperand &operand : llvm::make_early_inc_range (from.getUses ())) {
267296 bool replace = functor (operand);
268- if (replace)
297+ if (replace) {
298+ if (auto *fromOp = from.getDefiningOp ())
299+ getContext ()->executeAction <ReplaceOpAction>(
300+ []() {}, ArrayRef<IRUnit>{fromOp}, to);
269301 modifyOpInPlace (operand.getOwner (), [&]() { operand.set (to); });
302+ }
270303 allReplaced &= replace;
271304 }
272305 if (allUsesReplaced)
0 commit comments