Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 23 additions & 6 deletions mlir/lib/Rewrite/PatternApplicator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,12 @@ using namespace mlir::detail;
PatternApplicator::PatternApplicator(
const FrozenRewritePatternSet &frozenPatternList)
: frozenPatternList(frozenPatternList) {
#if MLIR_ENABLE_PDL_IN_PATTERNMATCH
if (const PDLByteCode *bytecode = frozenPatternList.getPDLByteCode()) {
mutableByteCodeState = std::make_unique<PDLByteCodeMutableState>();
bytecode->initializeMutableState(*mutableByteCodeState);
}
#endif
}
PatternApplicator::~PatternApplicator() = default;

Expand Down Expand Up @@ -54,12 +56,14 @@ static void logSucessfulPatternApplication(Operation *op) {
#endif

void PatternApplicator::applyCostModel(CostModel model) {
// Apply the cost model to the bytecode patterns first, and then the native
// patterns.
// Apply the cost model to the bytecode patterns first, and then the native
// patterns.
#if MLIR_ENABLE_PDL_IN_PATTERNMATCH
if (const PDLByteCode *bytecode = frozenPatternList.getPDLByteCode()) {
for (const auto &it : llvm::enumerate(bytecode->getPatterns()))
mutableByteCodeState->updatePatternBenefit(it.index(), model(it.value()));
}
#endif

// Copy over the patterns so that we can sort by benefit based on the cost
// model. Patterns that are already impossible to match are ignored.
Expand Down Expand Up @@ -121,24 +125,28 @@ void PatternApplicator::walkAllPatterns(
walk(*pattern);
for (const Pattern &it : frozenPatternList.getMatchAnyOpNativePatterns())
walk(it);
#if MLIR_ENABLE_PDL_IN_PATTERNMATCH
if (const PDLByteCode *bytecode = frozenPatternList.getPDLByteCode()) {
for (const Pattern &it : bytecode->getPatterns())
walk(it);
}
#endif
}

LogicalResult PatternApplicator::matchAndRewrite(
Operation *op, PatternRewriter &rewriter,
function_ref<bool(const Pattern &)> canApply,
function_ref<void(const Pattern &)> onFailure,
function_ref<LogicalResult(const Pattern &)> onSuccess) {
// Before checking native patterns, first match against the bytecode. This
// won't automatically perform any rewrites so there is no need to worry about
// conflicts.
// Before checking native patterns, first match against the bytecode. This
// won't automatically perform any rewrites so there is no need to worry about
// conflicts.
#ifdef MLIR_ENABLE_PDL_IN_PATTERNMATCH
SmallVector<PDLByteCode::MatchResult, 4> pdlMatches;
const PDLByteCode *bytecode = frozenPatternList.getPDLByteCode();
if (bytecode)
bytecode->match(op, rewriter, pdlMatches, *mutableByteCodeState);
#endif

// Check to see if there are patterns matching this specific operation type.
MutableArrayRef<const RewritePattern *> opPatterns;
Expand All @@ -150,7 +158,9 @@ LogicalResult PatternApplicator::matchAndRewrite(
// operation type in an interleaved fashion.
unsigned opIt = 0, opE = opPatterns.size();
unsigned anyIt = 0, anyE = anyOpPatterns.size();
#if MLIR_ENABLE_PDL_IN_PATTERNMATCH
unsigned pdlIt = 0, pdlE = pdlMatches.size();
#endif
LogicalResult result = failure();
do {
// Find the next pattern with the highest benefit.
Expand All @@ -168,6 +178,7 @@ LogicalResult PatternApplicator::matchAndRewrite(
bestPattern = anyOpPatterns[anyIt];
}

#if MLIR_ENABLE_PDL_IN_PATTERNMATCH
const PDLByteCode::MatchResult *pdlMatch = nullptr;
/// PDL patterns.
if (pdlIt < pdlE && (!bestPattern || bestPattern->getBenefit() <
Expand All @@ -176,6 +187,7 @@ LogicalResult PatternApplicator::matchAndRewrite(
pdlMatch = &pdlMatches[pdlIt];
bestPattern = pdlMatch->pattern;
}
#endif

if (!bestPattern)
break;
Expand All @@ -200,10 +212,13 @@ LogicalResult PatternApplicator::matchAndRewrite(
// pattern.
Operation *dumpRootOp = getDumpRootOp(op);
#endif
#if MLIR_ENABLE_PDL_IN_PATTERNMATCH
if (pdlMatch) {
result =
bytecode->rewrite(rewriter, *pdlMatch, *mutableByteCodeState);
} else {
} else
#endif
{
LLVM_DEBUG(llvm::dbgs() << "Trying to match \""
<< bestPattern->getDebugName() << "\"\n");

Expand Down Expand Up @@ -234,7 +249,9 @@ LogicalResult PatternApplicator::matchAndRewrite(
break;
} while (true);

#if MLIR_ENABLE_PDL_IN_PATTERNMATCH
if (mutableByteCodeState)
mutableByteCodeState->cleanupAfterMatchAndRewrite();
#endif
return result;
}
Loading