Skip to content

Commit 4a524ae

Browse files
committed
Guard against invalid block erasure. Support forwarding to null listeners.
1 parent 63d0a9b commit 4a524ae

File tree

5 files changed

+84
-23
lines changed

5 files changed

+84
-23
lines changed

mlir/docs/PatternRewriter.md

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -331,12 +331,17 @@ The driver performs a post-order traversal. Note that it walks regions of the
331331
given op but does not visit the op.
332332

333333
This driver does not (re)visit modified or newly replaced ops, and does not
334-
allow for progressive rewrites of the same op. Op erasure is only supported for
335-
the currently matched op. If your pattern set requires these, consider using the
336-
Greedy Pattern Rewrite Driver instead, at the expense of extra overhead.
334+
allow for progressive rewrites of the same op. Op and block erasure is only
335+
supported for the currently matched op and its descendant. If your pattern
336+
set requires these, consider using the Greedy Pattern Rewrite Driver instead,
337+
at the expense of extra overhead.
337338

338339
This driver is exposed using the `walkAndApplyPatterns` function.
339340

341+
Note: This driver listens for IR changes via the callbacks provided by
342+
`RewriterBase`. It is important that patterns announce all IR changes to the
343+
rewriter and do not bypass the rewriter API by modifying ops directly.
344+
340345
#### Debugging
341346

342347
You can debug the Walk Pattern Rewrite Driver by passing the

mlir/include/mlir/IR/PatternMatch.h

Lines changed: 17 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -461,54 +461,60 @@ class RewriterBase : public OpBuilder {
461461
/// struct can be used as a base to create listener chains, so that multiple
462462
/// listeners can be notified of IR changes.
463463
struct ForwardingListener : public RewriterBase::Listener {
464-
ForwardingListener(OpBuilder::Listener *listener) : listener(listener) {}
464+
ForwardingListener(OpBuilder::Listener *listener)
465+
: listener(listener),
466+
rewriteListener(
467+
dyn_cast_if_present<RewriterBase::Listener>(listener)) {}
465468

466469
void notifyOperationInserted(Operation *op, InsertPoint previous) override {
467-
listener->notifyOperationInserted(op, previous);
470+
if (listener)
471+
listener->notifyOperationInserted(op, previous);
468472
}
469473
void notifyBlockInserted(Block *block, Region *previous,
470474
Region::iterator previousIt) override {
471-
listener->notifyBlockInserted(block, previous, previousIt);
475+
if (listener)
476+
listener->notifyBlockInserted(block, previous, previousIt);
472477
}
473478
void notifyBlockErased(Block *block) override {
474-
if (auto *rewriteListener = dyn_cast<RewriterBase::Listener>(listener))
479+
if (rewriteListener)
475480
rewriteListener->notifyBlockErased(block);
476481
}
477482
void notifyOperationModified(Operation *op) override {
478-
if (auto *rewriteListener = dyn_cast<RewriterBase::Listener>(listener))
483+
if (rewriteListener)
479484
rewriteListener->notifyOperationModified(op);
480485
}
481486
void notifyOperationReplaced(Operation *op, Operation *newOp) override {
482-
if (auto *rewriteListener = dyn_cast<RewriterBase::Listener>(listener))
487+
if (rewriteListener)
483488
rewriteListener->notifyOperationReplaced(op, newOp);
484489
}
485490
void notifyOperationReplaced(Operation *op,
486491
ValueRange replacement) override {
487-
if (auto *rewriteListener = dyn_cast<RewriterBase::Listener>(listener))
492+
if (rewriteListener)
488493
rewriteListener->notifyOperationReplaced(op, replacement);
489494
}
490495
void notifyOperationErased(Operation *op) override {
491-
if (auto *rewriteListener = dyn_cast<RewriterBase::Listener>(listener))
496+
if (rewriteListener)
492497
rewriteListener->notifyOperationErased(op);
493498
}
494499
void notifyPatternBegin(const Pattern &pattern, Operation *op) override {
495-
if (auto *rewriteListener = dyn_cast<RewriterBase::Listener>(listener))
500+
if (rewriteListener)
496501
rewriteListener->notifyPatternBegin(pattern, op);
497502
}
498503
void notifyPatternEnd(const Pattern &pattern,
499504
LogicalResult status) override {
500-
if (auto *rewriteListener = dyn_cast<RewriterBase::Listener>(listener))
505+
if (rewriteListener)
501506
rewriteListener->notifyPatternEnd(pattern, status);
502507
}
503508
void notifyMatchFailure(
504509
Location loc,
505510
function_ref<void(Diagnostic &)> reasonCallback) override {
506-
if (auto *rewriteListener = dyn_cast<RewriterBase::Listener>(listener))
511+
if (rewriteListener)
507512
rewriteListener->notifyMatchFailure(loc, reasonCallback);
508513
}
509514

510515
private:
511516
OpBuilder::Listener *listener;
517+
RewriterBase::Listener *rewriteListener;
512518
};
513519

514520
/// Move the blocks that belong to "region" before the given position in

mlir/lib/Transforms/Utils/WalkPatternRewriteDriver.cpp

Lines changed: 21 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -35,21 +35,34 @@ struct WalkAndApplyPatternsAction final
3535
};
3636

3737
#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
38-
// Forwarding listener to guard against unsupported erasures. Because we use
39-
// walk-based pattern application, erasing the op from the *next* iteration
40-
// (e.g., a user of the visited op) is not valid.
41-
// Note that this is only used with expensive pattern API checks.
38+
// Forwarding listener to guard against unsupported erasures of non-descendant
39+
// ops/blocks. Because we use walk-based pattern application, erasing the
40+
// op/block from the *next* iteration (e.g., a user of the visited op) is not
41+
// valid. Note that this is only used with expensive pattern API checks.
4242
struct ErasedOpsListener final : RewriterBase::ForwardingListener {
4343
using RewriterBase::ForwardingListener::ForwardingListener;
4444

4545
void notifyOperationErased(Operation *op) override {
46-
if (op != visitedOp)
47-
llvm::report_fatal_error("unsupported op erased in WalkPatternRewriter; "
48-
"erasure is only supported for matched ops");
49-
46+
checkErasure(op);
5047
ForwardingListener::notifyOperationErased(op);
5148
}
5249

50+
void notifyBlockErased(Block *block) override {
51+
checkErasure(block->getParentOp());
52+
ForwardingListener::notifyBlockErased(block);
53+
}
54+
55+
void checkErasure(Operation *op) const {
56+
Operation *ancestorOp = op;
57+
while (ancestorOp && ancestorOp != visitedOp)
58+
ancestorOp = ancestorOp->getParentOp();
59+
60+
if (ancestorOp != visitedOp)
61+
llvm::report_fatal_error(
62+
"unsupported erased in WalkPatternRewriter; "
63+
"erasure is only supported for matched ops and their descendants");
64+
}
65+
5366
Operation *visitedOp = nullptr;
5467
};
5568
#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS

mlir/test/IR/test-walk-pattern-rewrite-driver.mlir

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -105,3 +105,17 @@ func.func @replace_with_new_op() -> i32 {
105105
%res = arith.addi %a, %a : i32
106106
return %res : i32
107107
}
108+
109+
// Check that we can erase nested blocks.
110+
// CHECK-LABEL: func.func @erase_nested_block
111+
// CHECK: %[[RES:.+]] = "test.erase_first_block"
112+
// CHECK-NEXT: foo.bar
113+
// CHECK: return %[[RES]]
114+
func.func @erase_nested_block() -> i32 {
115+
%a = "test.erase_first_block"() ({
116+
"foo.foo"() : () -> ()
117+
^bb1:
118+
"foo.bar"() : () -> ()
119+
}): () -> (i32)
120+
return %a : i32
121+
}

mlir/test/lib/Dialect/Test/TestPatterns.cpp

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -342,6 +342,28 @@ class ReplaceWithNewOp : public RewritePattern {
342342
}
343343
};
344344

345+
/// Erases the first child block of the matched "test.erase_first_block"
346+
/// operation.
347+
class EraseFirstBlock : public RewritePattern {
348+
public:
349+
EraseFirstBlock(MLIRContext *context)
350+
: RewritePattern("test.erase_first_block", /*benefit=*/1, context) {}
351+
352+
LogicalResult matchAndRewrite(Operation *op,
353+
PatternRewriter &rewriter) const override {
354+
llvm::errs() << "Num regions: " << op->getNumRegions() << "\n";
355+
for (Region &r : op->getRegions()) {
356+
for (Block &b : r.getBlocks()) {
357+
rewriter.eraseBlock(&b);
358+
llvm::errs() << "Erasing block: " << b << "\n";
359+
return success();
360+
}
361+
}
362+
363+
return failure();
364+
}
365+
};
366+
345367
struct TestGreedyPatternDriver
346368
: public PassWrapper<TestGreedyPatternDriver, OperationPass<>> {
347369
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestGreedyPatternDriver)
@@ -608,7 +630,8 @@ struct TestWalkPatternDriver final
608630

609631
// Patterns for testing the WalkPatternRewriteDriver.
610632
patterns.add<IncrementIntAttribute<3>, MoveBeforeParentOp,
611-
MoveAfterParentOp, CloneOp, ReplaceWithNewOp>(&getContext());
633+
MoveAfterParentOp, CloneOp, ReplaceWithNewOp, EraseFirstBlock>(
634+
&getContext());
612635

613636
DumpNotifications dumpListener;
614637
walkAndApplyPatterns(getOperation(), std::move(patterns),

0 commit comments

Comments
 (0)