@@ -848,7 +848,7 @@ namespace detail {
848848struct ConversionPatternRewriterImpl : public RewriterBase ::Listener {
849849 explicit ConversionPatternRewriterImpl (MLIRContext *ctx,
850850 const ConversionConfig &config)
851- : context(ctx), eraseRewriter(ctx), config(config) {}
851+ : context(ctx), config(config) {}
852852
853853 // ===--------------------------------------------------------------------===//
854854 // State Management
@@ -981,8 +981,11 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
981981 // / no new IR is created between calls to `eraseOp`/`eraseBlock`.
982982 struct SingleEraseRewriter : public RewriterBase , RewriterBase::Listener {
983983 public:
984- SingleEraseRewriter (MLIRContext *context)
985- : RewriterBase(context, /* listener=*/ this ) {}
984+ SingleEraseRewriter (
985+ MLIRContext *context,
986+ std::function<void (Operation *)> opErasedCallback = nullptr )
987+ : RewriterBase(context, /* listener=*/ this ),
988+ opErasedCallback (opErasedCallback) {}
986989
987990 // / Erase the given op (unless it was already erased).
988991 void eraseOp (Operation *op) override {
@@ -1003,13 +1006,20 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
10031006
10041007 bool wasErased (void *ptr) const { return erased.contains (ptr); }
10051008
1006- void notifyOperationErased (Operation *op) override { erased.insert (op); }
1009+ void notifyOperationErased (Operation *op) override {
1010+ erased.insert (op);
1011+ if (opErasedCallback)
1012+ opErasedCallback (op);
1013+ }
10071014
10081015 void notifyBlockErased (Block *block) override { erased.insert (block); }
10091016
10101017 private:
10111018 // / Pointers to all erased operations and blocks.
10121019 DenseSet<void *> erased;
1020+
1021+ // / A callback that is invoked when an operation is erased.
1022+ std::function<void (Operation *)> opErasedCallback;
10131023 };
10141024
10151025 // ===--------------------------------------------------------------------===//
@@ -1019,11 +1029,6 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
10191029 // / MLIR context.
10201030 MLIRContext *context;
10211031
1022- // / A rewriter that keeps track of ops/block that were already erased and
1023- // / skips duplicate op/block erasures. This rewriter is used during the
1024- // / "cleanup" phase.
1025- SingleEraseRewriter eraseRewriter;
1026-
10271032 // Mapping between replaced values that differ in type. This happens when
10281033 // replacing a value with one of a different type.
10291034 ConversionValueMapping mapping;
@@ -1195,6 +1200,11 @@ void ConversionPatternRewriterImpl::applyRewrites() {
11951200 rewrites[i]->commit (rewriter);
11961201
11971202 // Clean up all rewrites.
1203+ SingleEraseRewriter eraseRewriter (
1204+ context, /* opErasedCallback=*/ [&](Operation *op) {
1205+ if (auto castOp = dyn_cast<UnrealizedConversionCastOp>(op))
1206+ unresolvedMaterializations.erase (castOp);
1207+ });
11981208 for (auto &rewrite : rewrites)
11991209 rewrite->cleanup (eraseRewriter);
12001210}
@@ -2714,11 +2724,8 @@ LogicalResult OperationConverter::convertOperations(ArrayRef<Operation *> ops) {
27142724 SmallVector<UnrealizedConversionCastOp> allCastOps;
27152725 const DenseMap<UnrealizedConversionCastOp, UnresolvedMaterializationRewrite *>
27162726 &materializations = rewriterImpl.unresolvedMaterializations ;
2717- for (auto it : materializations) {
2718- if (rewriterImpl.eraseRewriter .wasErased (it.first ))
2719- continue ;
2727+ for (auto it : materializations)
27202728 allCastOps.push_back (it.first );
2721- }
27222729
27232730 // Reconcile all UnrealizedConversionCastOps that were inserted by the
27242731 // dialect conversion frameworks. (Not the one that were inserted by
0 commit comments