@@ -1996,6 +1996,7 @@ class OperationLegalizer {
1996
1996
// / Legalize the resultant IR after successfully applying the given pattern.
1997
1997
LogicalResult legalizePatternResult (Operation *op, const Pattern &pattern,
1998
1998
ConversionPatternRewriter &rewriter,
1999
+ const RewriterState &curState,
1999
2000
const SetVector<Operation *> &newOps,
2000
2001
const SetVector<Operation *> &modifiedOps,
2001
2002
const SetVector<Block *> &insertedBlocks);
@@ -2241,6 +2242,32 @@ OperationLegalizer::legalizeWithPattern(Operation *op,
2241
2242
ConversionPatternRewriter &rewriter) {
2242
2243
auto &rewriterImpl = rewriter.getImpl ();
2243
2244
2245
+ #if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
2246
+ Operation *checkOp;
2247
+ std::optional<OperationFingerPrint> topLevelFingerPrint;
2248
+ if (!rewriterImpl.config .allowPatternRollback ) {
2249
+ // The op may be getting erased, so we have to check the parent op.
2250
+ // (In rare cases, a pattern may even erase the parent op, which will cause
2251
+ // a crash here. Expensive checks are "best effort".) Skip the check if the
2252
+ // op does not have a parent op.
2253
+ if ((checkOp = op->getParentOp ())) {
2254
+ if (!op->getContext ()->isMultithreadingEnabled ()) {
2255
+ topLevelFingerPrint = OperationFingerPrint (checkOp);
2256
+ } else {
2257
+ // Another thread may be modifying a sibling operation. Therefore, the
2258
+ // fingerprinting mechanism of the parent op works only in
2259
+ // single-threaded mode.
2260
+ LLVM_DEBUG ({
2261
+ rewriterImpl.logger .startLine ()
2262
+ << " WARNING: Multi-threadeding is enabled. Some dialect "
2263
+ " conversion expensive checks are skipped in multithreading "
2264
+ " mode!\n " ;
2265
+ });
2266
+ }
2267
+ }
2268
+ }
2269
+ #endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
2270
+
2244
2271
// Functor that returns if the given pattern may be applied.
2245
2272
auto canApply = [&](const Pattern &pattern) {
2246
2273
bool canApply = canApplyPattern (op, pattern, rewriter);
@@ -2253,6 +2280,17 @@ OperationLegalizer::legalizeWithPattern(Operation *op,
2253
2280
RewriterState curState = rewriterImpl.getCurrentState ();
2254
2281
auto onFailure = [&](const Pattern &pattern) {
2255
2282
assert (rewriterImpl.pendingRootUpdates .empty () && " dangling root updates" );
2283
+ #if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
2284
+ if (!rewriterImpl.config .allowPatternRollback ) {
2285
+ // Returning "failure" after modifying IR is not allowed.
2286
+ if (checkOp) {
2287
+ OperationFingerPrint fingerPrintAfterPattern (checkOp);
2288
+ if (fingerPrintAfterPattern != *topLevelFingerPrint)
2289
+ llvm::report_fatal_error (" pattern '" + pattern.getDebugName () +
2290
+ " ' returned failure but IR did change" );
2291
+ }
2292
+ }
2293
+ #endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
2256
2294
rewriterImpl.patternNewOps .clear ();
2257
2295
rewriterImpl.patternModifiedOps .clear ();
2258
2296
rewriterImpl.patternInsertedBlocks .clear ();
@@ -2281,7 +2319,7 @@ OperationLegalizer::legalizeWithPattern(Operation *op,
2281
2319
moveAndReset (rewriterImpl.patternModifiedOps );
2282
2320
SetVector<Block *> insertedBlocks =
2283
2321
moveAndReset (rewriterImpl.patternInsertedBlocks );
2284
- auto result = legalizePatternResult (op, pattern, rewriter, newOps,
2322
+ auto result = legalizePatternResult (op, pattern, rewriter, curState, newOps,
2285
2323
modifiedOps, insertedBlocks);
2286
2324
appliedPatterns.erase (&pattern);
2287
2325
if (failed (result)) {
@@ -2324,7 +2362,7 @@ bool OperationLegalizer::canApplyPattern(Operation *op, const Pattern &pattern,
2324
2362
2325
2363
LogicalResult OperationLegalizer::legalizePatternResult (
2326
2364
Operation *op, const Pattern &pattern, ConversionPatternRewriter &rewriter,
2327
- const SetVector<Operation *> &newOps,
2365
+ const RewriterState &curState, const SetVector<Operation *> &newOps,
2328
2366
const SetVector<Operation *> &modifiedOps,
2329
2367
const SetVector<Block *> &insertedBlocks) {
2330
2368
auto &impl = rewriter.getImpl ();
@@ -2340,7 +2378,8 @@ LogicalResult OperationLegalizer::legalizePatternResult(
2340
2378
return hasRewrite<ModifyOperationRewrite>(newRewrites, op);
2341
2379
};
2342
2380
if (!replacedRoot () && !updatedRootInPlace ())
2343
- llvm::report_fatal_error (" expected pattern to replace the root operation" );
2381
+ llvm::report_fatal_error (
2382
+ " expected pattern to replace the root operation or modify it in place" );
2344
2383
#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
2345
2384
2346
2385
// Legalize each of the actions registered during application.
0 commit comments