1717#include " mlir/IR/Operation.h"
1818#include " mlir/Interfaces/FunctionInterfaces.h"
1919#include " mlir/Rewrite/PatternApplicator.h"
20+ #include " llvm/ADT/ScopeExit.h"
2021#include " llvm/ADT/SmallPtrSet.h"
2122#include " llvm/Support/Debug.h"
2223#include " llvm/Support/FormatVariadic.h"
@@ -2216,23 +2217,39 @@ OperationLegalizer::legalizeWithFold(Operation *op,
22162217 rewriterImpl.logger .startLine () << " * Fold {\n " ;
22172218 rewriterImpl.logger .indent ();
22182219 });
2219- (void )rewriterImpl;
2220+
2221+ // Clear pattern state, so that the next pattern application starts with a
2222+ // clean slate. (The op/block sets are populated by listener notifications.)
2223+ auto cleanup = llvm::make_scope_exit ([&]() {
2224+ rewriterImpl.patternNewOps .clear ();
2225+ rewriterImpl.patternModifiedOps .clear ();
2226+ rewriterImpl.patternInsertedBlocks .clear ();
2227+ });
2228+
2229+ // Upon failure, undo all changes made by the folder.
2230+ RewriterState curState = rewriterImpl.getCurrentState ();
22202231
22212232 // Try to fold the operation.
22222233 StringRef opName = op->getName ().getStringRef ();
22232234 SmallVector<Value, 2 > replacementValues;
22242235 SmallVector<Operation *, 2 > newOps;
22252236 rewriter.setInsertionPoint (op);
2237+ rewriter.startOpModification (op);
22262238 if (failed (rewriter.tryFold (op, replacementValues, &newOps))) {
22272239 LLVM_DEBUG (logFailure (rewriterImpl.logger , " unable to fold" ));
2240+ rewriter.cancelOpModification (op);
22282241 return failure ();
22292242 }
2243+ rewriter.finalizeOpModification (op);
22302244
22312245 // An empty list of replacement values indicates that the fold was in-place.
22322246 // As the operation changed, a new legalization needs to be attempted.
22332247 if (replacementValues.empty ())
22342248 return legalize (op, rewriter);
22352249
2250+ // Insert a replacement for 'op' with the folded replacement values.
2251+ rewriter.replaceOp (op, replacementValues);
2252+
22362253 // Recursively legalize any new constant operations.
22372254 for (Operation *newOp : newOps) {
22382255 if (failed (legalize (newOp, rewriter))) {
@@ -2245,16 +2262,12 @@ OperationLegalizer::legalizeWithFold(Operation *op,
22452262 " op '" + opName +
22462263 " ' folder rollback of IR modifications requested" );
22472264 }
2248- // Legalization failed: erase all materialized constants.
2249- for (Operation *op : newOps)
2250- rewriter.eraseOp (op);
2265+ rewriterImpl.resetState (
2266+ curState, std::string (op->getName ().getStringRef ()) + " folder" );
22512267 return failure ();
22522268 }
22532269 }
22542270
2255- // Insert a replacement for 'op' with the folded replacement values.
2256- rewriter.replaceOp (op, replacementValues);
2257-
22582271 LLVM_DEBUG (logSuccess (rewriterImpl.logger , " " ));
22592272 return success ();
22602273}
0 commit comments