@@ -458,6 +458,22 @@ struct LinalgDetensorize
458458 }
459459 };
460460
461+ // / A listener that forwards notifyBlockErased and notifyOperationErased to
462+ // / the given callbacks.
463+ struct CallbackListener : public RewriterBase ::Listener {
464+ CallbackListener (std::function<void (Operation *op)> onOperationErased,
465+ std::function<void (Block *block)> onBlockErased)
466+ : onOperationErased(onOperationErased), onBlockErased(onBlockErased) {}
467+
468+ void notifyBlockErased (Block *block) override { onBlockErased (block); }
469+ void notifyOperationErased (Operation *op) override {
470+ onOperationErased (op);
471+ }
472+
473+ std::function<void (Operation *op)> onOperationErased;
474+ std::function<void (Block *block)> onBlockErased;
475+ };
476+
461477 void runOnOperation () override {
462478 MLIRContext *context = &getContext ();
463479 DetensorizeTypeConverter typeConverter;
@@ -551,8 +567,22 @@ struct LinalgDetensorize
551567 populateBranchOpInterfaceTypeConversionPattern (patterns, typeConverter,
552568 shouldConvertBranchOperand);
553569
554- if (failed (
555- applyFullConversion (getOperation (), target, std::move (patterns))))
570+ ConversionConfig config;
571+ auto onOperationErased = [&](Operation *op) {
572+ opsToDetensor.erase (op);
573+ detensorableBranchOps.erase (op);
574+ };
575+ auto onBlockErased = [&](Block *block) {
576+ for (BlockArgument arg : block->getArguments ()) {
577+ blockArgsToDetensor.erase (arg);
578+ }
579+ };
580+ CallbackListener listener (onOperationErased, onBlockErased);
581+
582+ config.listener = &listener;
583+ config.allowPatternRollback = false ;
584+ if (failed (applyFullConversion (getOperation (), target, std::move (patterns),
585+ config)))
556586 signalPassFailure ();
557587
558588 RewritePatternSet canonPatterns (context);
0 commit comments