Skip to content

Commit 3024f0a

Browse files
[mlir][Transforms] Dialect Conversion: Fix folder rollback
1 parent 3aeab92 commit 3024f0a

File tree

2 files changed

+21
-8
lines changed

2 files changed

+21
-8
lines changed

mlir/lib/Transforms/Utils/DialectConversion.cpp

Lines changed: 20 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
#include "mlir/IR/Operation.h"
1818
#include "mlir/Interfaces/FunctionInterfaces.h"
1919
#include "mlir/Rewrite/PatternApplicator.h"
20+
#include "llvm/ADT/ScopeExit.h"
2021
#include "llvm/ADT/SmallPtrSet.h"
2122
#include "llvm/Support/Debug.h"
2223
#include "llvm/Support/FormatVariadic.h"
@@ -2216,23 +2217,39 @@ OperationLegalizer::legalizeWithFold(Operation *op,
22162217
rewriterImpl.logger.startLine() << "* Fold {\n";
22172218
rewriterImpl.logger.indent();
22182219
});
2219-
(void)rewriterImpl;
2220+
2221+
// Clear pattern state, so that the next pattern application starts with a
2222+
// clean slate. (The op/block sets are populated by listener notifications.)
2223+
auto cleanup = llvm::make_scope_exit([&]() {
2224+
rewriterImpl.patternNewOps.clear();
2225+
rewriterImpl.patternModifiedOps.clear();
2226+
rewriterImpl.patternInsertedBlocks.clear();
2227+
});
2228+
2229+
// Upon failure, undo all changes made by the folder.
2230+
RewriterState curState = rewriterImpl.getCurrentState();
22202231

22212232
// Try to fold the operation.
22222233
StringRef opName = op->getName().getStringRef();
22232234
SmallVector<Value, 2> replacementValues;
22242235
SmallVector<Operation *, 2> newOps;
22252236
rewriter.setInsertionPoint(op);
2237+
rewriter.startOpModification(op);
22262238
if (failed(rewriter.tryFold(op, replacementValues, &newOps))) {
22272239
LLVM_DEBUG(logFailure(rewriterImpl.logger, "unable to fold"));
2240+
rewriter.cancelOpModification(op);
22282241
return failure();
22292242
}
2243+
rewriter.finalizeOpModification(op);
22302244

22312245
// An empty list of replacement values indicates that the fold was in-place.
22322246
// As the operation changed, a new legalization needs to be attempted.
22332247
if (replacementValues.empty())
22342248
return legalize(op, rewriter);
22352249

2250+
// Insert a replacement for 'op' with the folded replacement values.
2251+
rewriter.replaceOp(op, replacementValues);
2252+
22362253
// Recursively legalize any new constant operations.
22372254
for (Operation *newOp : newOps) {
22382255
if (failed(legalize(newOp, rewriter))) {
@@ -2245,16 +2262,12 @@ OperationLegalizer::legalizeWithFold(Operation *op,
22452262
"op '" + opName +
22462263
"' folder rollback of IR modifications requested");
22472264
}
2248-
// Legalization failed: erase all materialized constants.
2249-
for (Operation *op : newOps)
2250-
rewriter.eraseOp(op);
2265+
rewriterImpl.resetState(
2266+
curState, std::string(op->getName().getStringRef()) + " folder");
22512267
return failure();
22522268
}
22532269
}
22542270

2255-
// Insert a replacement for 'op' with the folded replacement values.
2256-
rewriter.replaceOp(op, replacementValues);
2257-
22582271
LLVM_DEBUG(logSuccess(rewriterImpl.logger, ""));
22592272
return success();
22602273
}

mlir/test/Transforms/test-legalize-type-conversion.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -104,8 +104,8 @@ func.func @test_signature_conversion_no_converter() {
104104
"test.signature_conversion_no_converter"() ({
105105
// expected-error@below {{failed to legalize unresolved materialization from ('f64') to ('f32') that remained live after conversion}}
106106
^bb0(%arg0: f32):
107-
"test.type_consumer"(%arg0) : (f32) -> ()
108107
// expected-note@below{{see existing live user here}}
108+
"test.type_consumer"(%arg0) : (f32) -> ()
109109
"test.return"(%arg0) : (f32) -> ()
110110
}) : () -> ()
111111
return

0 commit comments

Comments
 (0)