Skip to content

Commit c639475

Browse files
[mlir][Transforms] Dialect Conversion: Fix folder implementation (#150775)
Operation folders can do two things: 1. Modify IR (in-place op modification). Failing to legalize an in-place folded operation does not trigger an immediate rollback. This happens only if the driver decides to try a different lowering path, requiring it to roll back a bunch of modifications, including the application of the folder. 2. Create new IR (constant op materialization of a folded attribute). Failing to legalize a newly created constant op triggers an immediate rollback. In-place op modifications should be guarded by `startOpModification`/`finalizeOpModification` because they are no different from other in-place op modifications. (They just happen outside of a pattern, but that does not mean that we should not track those changes; we are tracking everything else.) This commit adds those two function calls. This commit also moves the `rewriter.replaceOp(op, replacementValues);` function call before the loop nest that legalizes the newly created constant ops (and therefore `replacementValues`). Conceptually, the folded op must be replaced before attempting to legalize the constants because the constant ops may themselves be replaced as part of their own legalization process. The previous implementation happened to work in the current conversion driver, but is incompatible with the One-Shot Dialect Conversion driver, which expects to see the most recent IR at all time. From an end-user perspective, this commit should be NFC. A common folder-rollback pattern that is exercised by multiple tests cases: A `memref.dim` is folded to `arith.constant`, but `arith.constant` is not marked as legal as per the conversion target, triggering a rollback. Note: Folding is generally unsafe in a dialect conversion (see #92683), but that's a different issue. (In a One-Shot Dialect Conversion, it will no longer be unsafe.)
1 parent cf1abe6 commit c639475

File tree

2 files changed

+21
-8
lines changed

2 files changed

+21
-8
lines changed

mlir/lib/Transforms/Utils/DialectConversion.cpp

Lines changed: 20 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
#include "mlir/IR/Operation.h"
1818
#include "mlir/Interfaces/FunctionInterfaces.h"
1919
#include "mlir/Rewrite/PatternApplicator.h"
20+
#include "llvm/ADT/ScopeExit.h"
2021
#include "llvm/ADT/SmallPtrSet.h"
2122
#include "llvm/Support/Debug.h"
2223
#include "llvm/Support/FormatVariadic.h"
@@ -2240,23 +2241,39 @@ OperationLegalizer::legalizeWithFold(Operation *op,
22402241
rewriterImpl.logger.startLine() << "* Fold {\n";
22412242
rewriterImpl.logger.indent();
22422243
});
2243-
(void)rewriterImpl;
2244+
2245+
// Clear pattern state, so that the next pattern application starts with a
2246+
// clean slate. (The op/block sets are populated by listener notifications.)
2247+
auto cleanup = llvm::make_scope_exit([&]() {
2248+
rewriterImpl.patternNewOps.clear();
2249+
rewriterImpl.patternModifiedOps.clear();
2250+
rewriterImpl.patternInsertedBlocks.clear();
2251+
});
2252+
2253+
// Upon failure, undo all changes made by the folder.
2254+
RewriterState curState = rewriterImpl.getCurrentState();
22442255

22452256
// Try to fold the operation.
22462257
StringRef opName = op->getName().getStringRef();
22472258
SmallVector<Value, 2> replacementValues;
22482259
SmallVector<Operation *, 2> newOps;
22492260
rewriter.setInsertionPoint(op);
2261+
rewriter.startOpModification(op);
22502262
if (failed(rewriter.tryFold(op, replacementValues, &newOps))) {
22512263
LLVM_DEBUG(logFailure(rewriterImpl.logger, "unable to fold"));
2264+
rewriter.cancelOpModification(op);
22522265
return failure();
22532266
}
2267+
rewriter.finalizeOpModification(op);
22542268

22552269
// An empty list of replacement values indicates that the fold was in-place.
22562270
// As the operation changed, a new legalization needs to be attempted.
22572271
if (replacementValues.empty())
22582272
return legalize(op, rewriter);
22592273

2274+
// Insert a replacement for 'op' with the folded replacement values.
2275+
rewriter.replaceOp(op, replacementValues);
2276+
22602277
// Recursively legalize any new constant operations.
22612278
for (Operation *newOp : newOps) {
22622279
if (failed(legalize(newOp, rewriter))) {
@@ -2269,16 +2286,12 @@ OperationLegalizer::legalizeWithFold(Operation *op,
22692286
"op '" + opName +
22702287
"' folder rollback of IR modifications requested");
22712288
}
2272-
// Legalization failed: erase all materialized constants.
2273-
for (Operation *op : newOps)
2274-
rewriter.eraseOp(op);
2289+
rewriterImpl.resetState(
2290+
curState, std::string(op->getName().getStringRef()) + " folder");
22752291
return failure();
22762292
}
22772293
}
22782294

2279-
// Insert a replacement for 'op' with the folded replacement values.
2280-
rewriter.replaceOp(op, replacementValues);
2281-
22822295
LLVM_DEBUG(logSuccess(rewriterImpl.logger, ""));
22832296
return success();
22842297
}

mlir/test/Transforms/test-legalize-type-conversion.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -104,8 +104,8 @@ func.func @test_signature_conversion_no_converter() {
104104
"test.signature_conversion_no_converter"() ({
105105
// expected-error@below {{failed to legalize unresolved materialization from ('f64') to ('f32') that remained live after conversion}}
106106
^bb0(%arg0: f32):
107-
"test.type_consumer"(%arg0) : (f32) -> ()
108107
// expected-note@below{{see existing live user here}}
108+
"test.type_consumer"(%arg0) : (f32) -> ()
109109
"test.return"(%arg0) : (f32) -> ()
110110
}) : () -> ()
111111
return

0 commit comments

Comments
 (0)