1717#include " mlir/IR/Operation.h"
1818#include " mlir/Interfaces/FunctionInterfaces.h"
1919#include " mlir/Rewrite/PatternApplicator.h"
20+ #include " llvm/ADT/ScopeExit.h"
2021#include " llvm/ADT/SmallPtrSet.h"
2122#include " llvm/Support/Debug.h"
2223#include " llvm/Support/FormatVariadic.h"
@@ -1759,6 +1760,12 @@ void ConversionPatternRewriter::replaceOp(Operation *op, ValueRange newValues) {
17591760 impl->logger .startLine ()
17601761 << " ** Replace : '" << op->getName () << " '(" << op << " )\n " ;
17611762 });
1763+
1764+ // If the current insertion point is before the erased operation, we adjust
1765+ // the insertion point to be after the operation.
1766+ if (getInsertionPoint () == op->getIterator ())
1767+ setInsertionPointAfter (op);
1768+
17621769 SmallVector<SmallVector<Value>> newVals =
17631770 llvm::map_to_vector (newValues, [](Value v) -> SmallVector<Value> {
17641771 return v ? SmallVector<Value>{v} : SmallVector<Value>();
@@ -1774,6 +1781,12 @@ void ConversionPatternRewriter::replaceOpWithMultiple(
17741781 impl->logger .startLine ()
17751782 << " ** Replace : '" << op->getName () << " '(" << op << " )\n " ;
17761783 });
1784+
1785+ // If the current insertion point is before the erased operation, we adjust
1786+ // the insertion point to be after the operation.
1787+ if (getInsertionPoint () == op->getIterator ())
1788+ setInsertionPointAfter (op);
1789+
17771790 impl->replaceOp (op, std::move (newValues));
17781791}
17791792
@@ -1782,6 +1795,12 @@ void ConversionPatternRewriter::eraseOp(Operation *op) {
17821795 impl->logger .startLine ()
17831796 << " ** Erase : '" << op->getName () << " '(" << op << " )\n " ;
17841797 });
1798+
1799+ // If the current insertion point is before the erased operation, we adjust
1800+ // the insertion point to be after the operation.
1801+ if (getInsertionPoint () == op->getIterator ())
1802+ setInsertionPointAfter (op);
1803+
17851804 SmallVector<SmallVector<Value>> nullRepls (op->getNumResults (), {});
17861805 impl->replaceOp (op, std::move (nullRepls));
17871806}
@@ -1888,6 +1907,11 @@ void ConversionPatternRewriter::inlineBlockBefore(Block *source, Block *dest,
18881907 moveOpBefore (&source->front (), dest, before);
18891908 }
18901909
1910+ // If the current insertion point is within the source block, adjust the
1911+ // insertion point to the destination block.
1912+ if (getInsertionBlock () == source)
1913+ setInsertionPoint (dest, getInsertionPoint ());
1914+
18911915 // Erase the source block.
18921916 eraseBlock (source);
18931917}
@@ -2217,23 +2241,39 @@ OperationLegalizer::legalizeWithFold(Operation *op,
22172241 rewriterImpl.logger .startLine () << " * Fold {\n " ;
22182242 rewriterImpl.logger .indent ();
22192243 });
2220- (void )rewriterImpl;
2244+
2245+ // Clear pattern state, so that the next pattern application starts with a
2246+ // clean slate. (The op/block sets are populated by listener notifications.)
2247+ auto cleanup = llvm::make_scope_exit ([&]() {
2248+ rewriterImpl.patternNewOps .clear ();
2249+ rewriterImpl.patternModifiedOps .clear ();
2250+ rewriterImpl.patternInsertedBlocks .clear ();
2251+ });
2252+
2253+ // Upon failure, undo all changes made by the folder.
2254+ RewriterState curState = rewriterImpl.getCurrentState ();
22212255
22222256 // Try to fold the operation.
22232257 StringRef opName = op->getName ().getStringRef ();
22242258 SmallVector<Value, 2 > replacementValues;
22252259 SmallVector<Operation *, 2 > newOps;
22262260 rewriter.setInsertionPoint (op);
2261+ rewriter.startOpModification (op);
22272262 if (failed (rewriter.tryFold (op, replacementValues, &newOps))) {
22282263 LLVM_DEBUG (logFailure (rewriterImpl.logger , " unable to fold" ));
2264+ rewriter.cancelOpModification (op);
22292265 return failure ();
22302266 }
2267+ rewriter.finalizeOpModification (op);
22312268
22322269 // An empty list of replacement values indicates that the fold was in-place.
22332270 // As the operation changed, a new legalization needs to be attempted.
22342271 if (replacementValues.empty ())
22352272 return legalize (op, rewriter);
22362273
2274+ // Insert a replacement for 'op' with the folded replacement values.
2275+ rewriter.replaceOp (op, replacementValues);
2276+
22372277 // Recursively legalize any new constant operations.
22382278 for (Operation *newOp : newOps) {
22392279 if (failed (legalize (newOp, rewriter))) {
@@ -2246,16 +2286,12 @@ OperationLegalizer::legalizeWithFold(Operation *op,
22462286 " op '" + opName +
22472287 " ' folder rollback of IR modifications requested" );
22482288 }
2249- // Legalization failed: erase all materialized constants.
2250- for (Operation *op : newOps)
2251- rewriter.eraseOp (op);
2289+ rewriterImpl.resetState (
2290+ curState, std::string (op->getName ().getStringRef ()) + " folder" );
22522291 return failure ();
22532292 }
22542293 }
22552294
2256- // Insert a replacement for 'op' with the folded replacement values.
2257- rewriter.replaceOp (op, replacementValues);
2258-
22592295 LLVM_DEBUG (logSuccess (rewriterImpl.logger , " " ));
22602296 return success ();
22612297}
0 commit comments