Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
99 changes: 89 additions & 10 deletions mlir/lib/Transforms/Utils/WalkPatternRewriteDriver.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,14 @@
#include "mlir/Transforms/WalkPatternRewriteDriver.h"

#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/OperationSupport.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/Verifier.h"
#include "mlir/IR/Visitors.h"
#include "mlir/Rewrite/PatternApplicator.h"
#include "llvm/Support/Debug.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/DebugLog.h"
#include "llvm/Support/ErrorHandling.h"

#define DEBUG_TYPE "walk-rewriter"
Expand Down Expand Up @@ -88,20 +90,97 @@ void walkAndApplyPatterns(Operation *op,
PatternApplicator applicator(patterns);
applicator.applyDefaultCostModel();

// Iterator on all reachable operations in the region.
// Also keep track if we visited the nested regions of the current op
// already to drive the post-order traversal.
struct RegionReachableOpIterator {
RegionReachableOpIterator(Region *region) : region(region) {
regionIt = region->begin();
if (regionIt != region->end())
blockIt = regionIt->begin();
}
// Advance the iterator to the next reachable operation.
void advance() {
assert(regionIt != region->end());
hasVisitedRegions = false;
if (blockIt == regionIt->end()) {
++regionIt;
if (regionIt != region->end())
blockIt = regionIt->begin();
return;
}
++blockIt;
if (blockIt != regionIt->end()) {
LDBG() << "Incrementing block iterator, next op: "
<< OpWithFlags(&*blockIt, OpPrintingFlags().skipRegions());
}
}
// The region we're iterating over.
Region *region;
// The Block currently being iterated over.
Region::iterator regionIt;
// The Operation currently being iterated over.
Block::iterator blockIt;
// Whether we've visited the nested regions of the current op already.
bool hasVisitedRegions = false;
};

// Worklist of regions to visit to drive the post-order traversal.
SmallVector<RegionReachableOpIterator> worklist;

LDBG() << "Starting walk-based pattern rewrite driver";
ctx->executeAction<WalkAndApplyPatternsAction>(
[&] {
// Perform a post-order traversal of the regions, visiting each
// reachable operation.
for (Region &region : op->getRegions()) {
region.walk([&](Operation *visitedOp) {
LLVM_DEBUG(llvm::dbgs() << "Visiting op: "; visitedOp->print(
llvm::dbgs(), OpPrintingFlags().skipRegions());
llvm::dbgs() << "\n";);
assert(worklist.empty());
if (region.empty())
continue;

// Prime the worklist with the entry block of this region.
worklist.push_back({&region});
while (!worklist.empty()) {
RegionReachableOpIterator &it = worklist.back();
if (it.regionIt == it.region->end()) {
// We're done with this region.
worklist.pop_back();
continue;
}
if (it.blockIt == it.regionIt->end()) {
// We're done with this block.
it.advance();
continue;
}
Operation *op = &*it.blockIt;
// If we haven't visited the nested regions of this op yet,
// enqueue them.
if (!it.hasVisitedRegions) {
it.hasVisitedRegions = true;
for (Region &nestedRegion : llvm::reverse(op->getRegions())) {
if (nestedRegion.empty())
continue;
worklist.push_back({&nestedRegion});
}
}
// If we're not at the back of the worklist, we've enqueued some
// nested region for processing. We'll come back to this op later
// (post-order)
if (&it != &worklist.back())
continue;

// Preemptively increment the iterator, in case the current op
// would be erased.
it.advance();

LDBG() << "Visiting op: "
<< OpWithFlags(op, OpPrintingFlags().skipRegions());
#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
erasedListener.visitedOp = visitedOp;
erasedListener.visitedOp = op;
#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
if (succeeded(applicator.matchAndRewrite(visitedOp, rewriter))) {
LLVM_DEBUG(llvm::dbgs() << "\tOp matched and rewritten\n";);
}
});
if (succeeded(applicator.matchAndRewrite(op, rewriter)))
LDBG() << "\tOp matched and rewritten";
}
}
},
{op});
Expand Down
4 changes: 2 additions & 2 deletions mlir/test/IR/test-walk-pattern-rewrite-driver.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -40,12 +40,12 @@ func.func @move_before(%cond : i1) {
}

// Check that the driver handles rewriter.moveAfter. In this case, we expect
// the moved op to be visited only once since walk uses `make_early_inc_range`.
// the moved op to be visited twice.
// CHECK-LABEL: func.func @move_after(
// CHECK: scf.if
// CHECK: }
// CHECK: "test.move_after_parent_op"
// CHECK: "test.any_attr_of_i32_str"() <{attr = 1 : i32}> : () -> ()
// CHECK: "test.any_attr_of_i32_str"() <{attr = 2 : i32}> : () -> ()
// CHECK: return
func.func @move_after(%cond : i1) {
scf.if %cond {
Expand Down