diff --git a/mlir/lib/Rewrite/PatternApplicator.cpp b/mlir/lib/Rewrite/PatternApplicator.cpp index ea43f8a147d47..91cc547e7a227 100644 --- a/mlir/lib/Rewrite/PatternApplicator.cpp +++ b/mlir/lib/Rewrite/PatternApplicator.cpp @@ -23,10 +23,12 @@ using namespace mlir::detail; PatternApplicator::PatternApplicator( const FrozenRewritePatternSet &frozenPatternList) : frozenPatternList(frozenPatternList) { +#if MLIR_ENABLE_PDL_IN_PATTERNMATCH if (const PDLByteCode *bytecode = frozenPatternList.getPDLByteCode()) { mutableByteCodeState = std::make_unique(); bytecode->initializeMutableState(*mutableByteCodeState); } +#endif } PatternApplicator::~PatternApplicator() = default; @@ -54,12 +56,14 @@ static void logSucessfulPatternApplication(Operation *op) { #endif void PatternApplicator::applyCostModel(CostModel model) { - // Apply the cost model to the bytecode patterns first, and then the native - // patterns. +// Apply the cost model to the bytecode patterns first, and then the native +// patterns. +#if MLIR_ENABLE_PDL_IN_PATTERNMATCH if (const PDLByteCode *bytecode = frozenPatternList.getPDLByteCode()) { for (const auto &it : llvm::enumerate(bytecode->getPatterns())) mutableByteCodeState->updatePatternBenefit(it.index(), model(it.value())); } +#endif // Copy over the patterns so that we can sort by benefit based on the cost // model. Patterns that are already impossible to match are ignored. @@ -121,10 +125,12 @@ void PatternApplicator::walkAllPatterns( walk(*pattern); for (const Pattern &it : frozenPatternList.getMatchAnyOpNativePatterns()) walk(it); +#if MLIR_ENABLE_PDL_IN_PATTERNMATCH if (const PDLByteCode *bytecode = frozenPatternList.getPDLByteCode()) { for (const Pattern &it : bytecode->getPatterns()) walk(it); } +#endif } LogicalResult PatternApplicator::matchAndRewrite( @@ -132,13 +138,15 @@ LogicalResult PatternApplicator::matchAndRewrite( function_ref canApply, function_ref onFailure, function_ref onSuccess) { - // Before checking native patterns, first match against the bytecode. This - // won't automatically perform any rewrites so there is no need to worry about - // conflicts. +// Before checking native patterns, first match against the bytecode. This +// won't automatically perform any rewrites so there is no need to worry about +// conflicts. +#ifdef MLIR_ENABLE_PDL_IN_PATTERNMATCH SmallVector pdlMatches; const PDLByteCode *bytecode = frozenPatternList.getPDLByteCode(); if (bytecode) bytecode->match(op, rewriter, pdlMatches, *mutableByteCodeState); +#endif // Check to see if there are patterns matching this specific operation type. MutableArrayRef opPatterns; @@ -150,7 +158,9 @@ LogicalResult PatternApplicator::matchAndRewrite( // operation type in an interleaved fashion. unsigned opIt = 0, opE = opPatterns.size(); unsigned anyIt = 0, anyE = anyOpPatterns.size(); +#if MLIR_ENABLE_PDL_IN_PATTERNMATCH unsigned pdlIt = 0, pdlE = pdlMatches.size(); +#endif LogicalResult result = failure(); do { // Find the next pattern with the highest benefit. @@ -168,6 +178,7 @@ LogicalResult PatternApplicator::matchAndRewrite( bestPattern = anyOpPatterns[anyIt]; } +#if MLIR_ENABLE_PDL_IN_PATTERNMATCH const PDLByteCode::MatchResult *pdlMatch = nullptr; /// PDL patterns. if (pdlIt < pdlE && (!bestPattern || bestPattern->getBenefit() < @@ -176,6 +187,7 @@ LogicalResult PatternApplicator::matchAndRewrite( pdlMatch = &pdlMatches[pdlIt]; bestPattern = pdlMatch->pattern; } +#endif if (!bestPattern) break; @@ -200,10 +212,13 @@ LogicalResult PatternApplicator::matchAndRewrite( // pattern. Operation *dumpRootOp = getDumpRootOp(op); #endif +#if MLIR_ENABLE_PDL_IN_PATTERNMATCH if (pdlMatch) { result = bytecode->rewrite(rewriter, *pdlMatch, *mutableByteCodeState); - } else { + } else +#endif + { LLVM_DEBUG(llvm::dbgs() << "Trying to match \"" << bestPattern->getDebugName() << "\"\n"); @@ -234,7 +249,9 @@ LogicalResult PatternApplicator::matchAndRewrite( break; } while (true); +#if MLIR_ENABLE_PDL_IN_PATTERNMATCH if (mutableByteCodeState) mutableByteCodeState->cleanupAfterMatchAndRewrite(); +#endif return result; }