@@ -461,54 +461,60 @@ class RewriterBase : public OpBuilder {
461461 // / struct can be used as a base to create listener chains, so that multiple
462462 // / listeners can be notified of IR changes.
463463 struct ForwardingListener : public RewriterBase ::Listener {
464- ForwardingListener (OpBuilder::Listener *listener) : listener(listener) {}
464+ ForwardingListener (OpBuilder::Listener *listener)
465+ : listener(listener),
466+ rewriteListener (
467+ dyn_cast_if_present<RewriterBase::Listener>(listener)) {}
465468
466469 void notifyOperationInserted (Operation *op, InsertPoint previous) override {
467- listener->notifyOperationInserted (op, previous);
470+ if (listener)
471+ listener->notifyOperationInserted (op, previous);
468472 }
469473 void notifyBlockInserted (Block *block, Region *previous,
470474 Region::iterator previousIt) override {
471- listener->notifyBlockInserted (block, previous, previousIt);
475+ if (listener)
476+ listener->notifyBlockInserted (block, previous, previousIt);
472477 }
473478 void notifyBlockErased (Block *block) override {
474- if (auto * rewriteListener = dyn_cast<RewriterBase::Listener>(listener) )
479+ if (rewriteListener)
475480 rewriteListener->notifyBlockErased (block);
476481 }
477482 void notifyOperationModified (Operation *op) override {
478- if (auto * rewriteListener = dyn_cast<RewriterBase::Listener>(listener) )
483+ if (rewriteListener)
479484 rewriteListener->notifyOperationModified (op);
480485 }
481486 void notifyOperationReplaced (Operation *op, Operation *newOp) override {
482- if (auto * rewriteListener = dyn_cast<RewriterBase::Listener>(listener) )
487+ if (rewriteListener)
483488 rewriteListener->notifyOperationReplaced (op, newOp);
484489 }
485490 void notifyOperationReplaced (Operation *op,
486491 ValueRange replacement) override {
487- if (auto * rewriteListener = dyn_cast<RewriterBase::Listener>(listener) )
492+ if (rewriteListener)
488493 rewriteListener->notifyOperationReplaced (op, replacement);
489494 }
490495 void notifyOperationErased (Operation *op) override {
491- if (auto * rewriteListener = dyn_cast<RewriterBase::Listener>(listener) )
496+ if (rewriteListener)
492497 rewriteListener->notifyOperationErased (op);
493498 }
494499 void notifyPatternBegin (const Pattern &pattern, Operation *op) override {
495- if (auto * rewriteListener = dyn_cast<RewriterBase::Listener>(listener) )
500+ if (rewriteListener)
496501 rewriteListener->notifyPatternBegin (pattern, op);
497502 }
498503 void notifyPatternEnd (const Pattern &pattern,
499504 LogicalResult status) override {
500- if (auto * rewriteListener = dyn_cast<RewriterBase::Listener>(listener) )
505+ if (rewriteListener)
501506 rewriteListener->notifyPatternEnd (pattern, status);
502507 }
503508 void notifyMatchFailure (
504509 Location loc,
505510 function_ref<void (Diagnostic &)> reasonCallback) override {
506- if (auto * rewriteListener = dyn_cast<RewriterBase::Listener>(listener) )
511+ if (rewriteListener)
507512 rewriteListener->notifyMatchFailure (loc, reasonCallback);
508513 }
509514
510515 private:
511516 OpBuilder::Listener *listener;
517+ RewriterBase::Listener *rewriteListener;
512518 };
513519
514520 // / Move the blocks that belong to "region" before the given position in
0 commit comments