Skip to content

Commit 08e101e

Browse files
[mlir][Transforms] Dialect Conversion: check for "failure" after modification (#150748)
Add a new "expensive check" when running with `allowPatternRollback = false`: returning "failure" after modifying IR is no longer allowed. This check detects a few more API violations in addition to the check `undoRewrites`. The latter check will be removed soon. (Because the One-Shot Dialect Conversion will no longer maintain the stack of IR rewrites.) Also fix a build error when expensive checks are enabled.
1 parent 931228e commit 08e101e

File tree

1 file changed

+42
-3
lines changed

1 file changed

+42
-3
lines changed

mlir/lib/Transforms/Utils/DialectConversion.cpp

Lines changed: 42 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1996,6 +1996,7 @@ class OperationLegalizer {
19961996
/// Legalize the resultant IR after successfully applying the given pattern.
19971997
LogicalResult legalizePatternResult(Operation *op, const Pattern &pattern,
19981998
ConversionPatternRewriter &rewriter,
1999+
const RewriterState &curState,
19992000
const SetVector<Operation *> &newOps,
20002001
const SetVector<Operation *> &modifiedOps,
20012002
const SetVector<Block *> &insertedBlocks);
@@ -2241,6 +2242,32 @@ OperationLegalizer::legalizeWithPattern(Operation *op,
22412242
ConversionPatternRewriter &rewriter) {
22422243
auto &rewriterImpl = rewriter.getImpl();
22432244

2245+
#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
2246+
Operation *checkOp;
2247+
std::optional<OperationFingerPrint> topLevelFingerPrint;
2248+
if (!rewriterImpl.config.allowPatternRollback) {
2249+
// The op may be getting erased, so we have to check the parent op.
2250+
// (In rare cases, a pattern may even erase the parent op, which will cause
2251+
// a crash here. Expensive checks are "best effort".) Skip the check if the
2252+
// op does not have a parent op.
2253+
if ((checkOp = op->getParentOp())) {
2254+
if (!op->getContext()->isMultithreadingEnabled()) {
2255+
topLevelFingerPrint = OperationFingerPrint(checkOp);
2256+
} else {
2257+
// Another thread may be modifying a sibling operation. Therefore, the
2258+
// fingerprinting mechanism of the parent op works only in
2259+
// single-threaded mode.
2260+
LLVM_DEBUG({
2261+
rewriterImpl.logger.startLine()
2262+
<< "WARNING: Multi-threadeding is enabled. Some dialect "
2263+
"conversion expensive checks are skipped in multithreading "
2264+
"mode!\n";
2265+
});
2266+
}
2267+
}
2268+
}
2269+
#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
2270+
22442271
// Functor that returns if the given pattern may be applied.
22452272
auto canApply = [&](const Pattern &pattern) {
22462273
bool canApply = canApplyPattern(op, pattern, rewriter);
@@ -2253,6 +2280,17 @@ OperationLegalizer::legalizeWithPattern(Operation *op,
22532280
RewriterState curState = rewriterImpl.getCurrentState();
22542281
auto onFailure = [&](const Pattern &pattern) {
22552282
assert(rewriterImpl.pendingRootUpdates.empty() && "dangling root updates");
2283+
#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
2284+
if (!rewriterImpl.config.allowPatternRollback) {
2285+
// Returning "failure" after modifying IR is not allowed.
2286+
if (checkOp) {
2287+
OperationFingerPrint fingerPrintAfterPattern(checkOp);
2288+
if (fingerPrintAfterPattern != *topLevelFingerPrint)
2289+
llvm::report_fatal_error("pattern '" + pattern.getDebugName() +
2290+
"' returned failure but IR did change");
2291+
}
2292+
}
2293+
#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
22562294
rewriterImpl.patternNewOps.clear();
22572295
rewriterImpl.patternModifiedOps.clear();
22582296
rewriterImpl.patternInsertedBlocks.clear();
@@ -2281,7 +2319,7 @@ OperationLegalizer::legalizeWithPattern(Operation *op,
22812319
moveAndReset(rewriterImpl.patternModifiedOps);
22822320
SetVector<Block *> insertedBlocks =
22832321
moveAndReset(rewriterImpl.patternInsertedBlocks);
2284-
auto result = legalizePatternResult(op, pattern, rewriter, newOps,
2322+
auto result = legalizePatternResult(op, pattern, rewriter, curState, newOps,
22852323
modifiedOps, insertedBlocks);
22862324
appliedPatterns.erase(&pattern);
22872325
if (failed(result)) {
@@ -2324,7 +2362,7 @@ bool OperationLegalizer::canApplyPattern(Operation *op, const Pattern &pattern,
23242362

23252363
LogicalResult OperationLegalizer::legalizePatternResult(
23262364
Operation *op, const Pattern &pattern, ConversionPatternRewriter &rewriter,
2327-
const SetVector<Operation *> &newOps,
2365+
const RewriterState &curState, const SetVector<Operation *> &newOps,
23282366
const SetVector<Operation *> &modifiedOps,
23292367
const SetVector<Block *> &insertedBlocks) {
23302368
auto &impl = rewriter.getImpl();
@@ -2340,7 +2378,8 @@ LogicalResult OperationLegalizer::legalizePatternResult(
23402378
return hasRewrite<ModifyOperationRewrite>(newRewrites, op);
23412379
};
23422380
if (!replacedRoot() && !updatedRootInPlace())
2343-
llvm::report_fatal_error("expected pattern to replace the root operation");
2381+
llvm::report_fatal_error(
2382+
"expected pattern to replace the root operation or modify it in place");
23442383
#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
23452384

23462385
// Legalize each of the actions registered during application.

0 commit comments

Comments
 (0)