@@ -1996,6 +1996,7 @@ class OperationLegalizer {
19961996 // / Legalize the resultant IR after successfully applying the given pattern.
19971997 LogicalResult legalizePatternResult (Operation *op, const Pattern &pattern,
19981998 ConversionPatternRewriter &rewriter,
1999+ const RewriterState &curState,
19992000 const SetVector<Operation *> &newOps,
20002001 const SetVector<Operation *> &modifiedOps,
20012002 const SetVector<Block *> &insertedBlocks);
@@ -2241,6 +2242,32 @@ OperationLegalizer::legalizeWithPattern(Operation *op,
22412242 ConversionPatternRewriter &rewriter) {
22422243 auto &rewriterImpl = rewriter.getImpl ();
22432244
2245+ #if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
2246+ Operation *checkOp;
2247+ std::optional<OperationFingerPrint> topLevelFingerPrint;
2248+ if (!rewriterImpl.config .allowPatternRollback ) {
2249+ // The op may be getting erased, so we have to check the parent op.
2250+ // (In rare cases, a pattern may even erase the parent op, which will cause
2251+ // a crash here. Expensive checks are "best effort".) Skip the check if the
2252+ // op does not have a parent op.
2253+ if ((checkOp = op->getParentOp ())) {
2254+ if (!op->getContext ()->isMultithreadingEnabled ()) {
2255+ topLevelFingerPrint = OperationFingerPrint (checkOp);
2256+ } else {
2257+ // Another thread may be modifying a sibling operation. Therefore, the
2258+ // fingerprinting mechanism of the parent op works only in
2259+ // single-threaded mode.
2260+ LLVM_DEBUG ({
2261+ rewriterImpl.logger .startLine ()
2262+ << " WARNING: Multi-threadeding is enabled. Some dialect "
2263+ " conversion expensive checks are skipped in multithreading "
2264+ " mode!\n " ;
2265+ });
2266+ }
2267+ }
2268+ }
2269+ #endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
2270+
22442271 // Functor that returns if the given pattern may be applied.
22452272 auto canApply = [&](const Pattern &pattern) {
22462273 bool canApply = canApplyPattern (op, pattern, rewriter);
@@ -2253,6 +2280,17 @@ OperationLegalizer::legalizeWithPattern(Operation *op,
22532280 RewriterState curState = rewriterImpl.getCurrentState ();
22542281 auto onFailure = [&](const Pattern &pattern) {
22552282 assert (rewriterImpl.pendingRootUpdates .empty () && " dangling root updates" );
2283+ #if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
2284+ if (!rewriterImpl.config .allowPatternRollback ) {
2285+ // Returning "failure" after modifying IR is not allowed.
2286+ if (checkOp) {
2287+ OperationFingerPrint fingerPrintAfterPattern (checkOp);
2288+ if (fingerPrintAfterPattern != *topLevelFingerPrint)
2289+ llvm::report_fatal_error (" pattern '" + pattern.getDebugName () +
2290+ " ' returned failure but IR did change" );
2291+ }
2292+ }
2293+ #endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
22562294 rewriterImpl.patternNewOps .clear ();
22572295 rewriterImpl.patternModifiedOps .clear ();
22582296 rewriterImpl.patternInsertedBlocks .clear ();
@@ -2281,7 +2319,7 @@ OperationLegalizer::legalizeWithPattern(Operation *op,
22812319 moveAndReset (rewriterImpl.patternModifiedOps );
22822320 SetVector<Block *> insertedBlocks =
22832321 moveAndReset (rewriterImpl.patternInsertedBlocks );
2284- auto result = legalizePatternResult (op, pattern, rewriter, newOps,
2322+ auto result = legalizePatternResult (op, pattern, rewriter, curState, newOps,
22852323 modifiedOps, insertedBlocks);
22862324 appliedPatterns.erase (&pattern);
22872325 if (failed (result)) {
@@ -2324,7 +2362,7 @@ bool OperationLegalizer::canApplyPattern(Operation *op, const Pattern &pattern,
23242362
23252363LogicalResult OperationLegalizer::legalizePatternResult (
23262364 Operation *op, const Pattern &pattern, ConversionPatternRewriter &rewriter,
2327- const SetVector<Operation *> &newOps,
2365+ const RewriterState &curState, const SetVector<Operation *> &newOps,
23282366 const SetVector<Operation *> &modifiedOps,
23292367 const SetVector<Block *> &insertedBlocks) {
23302368 auto &impl = rewriter.getImpl ();
@@ -2340,7 +2378,8 @@ LogicalResult OperationLegalizer::legalizePatternResult(
23402378 return hasRewrite<ModifyOperationRewrite>(newRewrites, op);
23412379 };
23422380 if (!replacedRoot () && !updatedRootInPlace ())
2343- llvm::report_fatal_error (" expected pattern to replace the root operation" );
2381+ llvm::report_fatal_error (
2382+ " expected pattern to replace the root operation or modify it in place" );
23442383#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
23452384
23462385 // Legalize each of the actions registered during application.
0 commit comments