@PragmaTwice
Member

In MLIR, both PatternRewriter and OpBuilder typically have corresponding Listeners that monitor changes to the IR. For example, when a new operation is created, the listener's notifyOperationInserted method is invoked so that whoever is tracking IR changes (for example, the greedy rewrite driver) is informed of the change. (This is also why, in C++ rewrite patterns, one should use PatternRewriter::create instead of OpTy::create, since the latter does not trigger notifyOperationInserted.)

However, in Python, the listener methods are currently not invoked when executing rewrite patterns. While this may not always affect the outcome, it violates the semantics defined in MLIR’s C++ implementation. At present, rewrite patterns in Python directly construct new operations using the TableGen-generated op constructors, such as arith.addi(lhs, rhs).
Although we could introduce an API like rewriter.create("arith.addi", operands=[lhs, rhs], results=...), it would be less intuitive, and users might still bypass it and instantiate ops directly, which would again skip listener notifications.
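To make the problem concrete, here is a minimal pure-Python model, assuming a simplified listener with a single notify_operation_inserted hook. Listener, AddIOp and Rewriter are hypothetical stand-ins, not the MLIR Python API. Creating an op through the rewriter notifies the listener, while calling the op class directly silently bypasses it:

```python
class Listener:
    """Hypothetical stand-in for RewriterBase::Listener: records inserted ops."""

    def __init__(self):
        self.inserted = []

    def notify_operation_inserted(self, op):
        self.inserted.append(op)


class AddIOp:
    """Hypothetical op class; constructing it directly knows nothing about listeners."""

    def __init__(self, lhs, rhs):
        self.operands = (lhs, rhs)


class Rewriter:
    """Hypothetical rewriter whose create() builds the op and notifies its listener."""

    def __init__(self, listener):
        self.listener = listener

    def create(self, op_cls, *args):
        op = op_cls(*args)
        self.listener.notify_operation_inserted(op)
        return op


listener = Listener()
rewriter = Rewriter(listener)
via_rewriter = rewriter.create(AddIOp, "a", "b")  # listener is notified
direct = AddIOp("a", "b")  # listener never hears about this op
```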

In this PR, we adopt an approach similar to how context, location, and insertion point are managed: we maintain a stack of listeners. When arith.addi or Operation.create is called, the top listener on the stack is retrieved, and notifyOperationInserted is invoked automatically. This allows users to construct operations in the usual way, while ensuring that listeners are properly notified.
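The idea can be sketched in plain Python, assuming a per-thread stack analogous to the one used for context, location and insertion point. listener_scope, current_listener and create_op are hypothetical names for illustration only; the op "constructor" looks up the top listener and notifies it, so callers never pass the listener explicitly:

```python
import contextlib
import threading

# One listener stack per thread (hypothetical simplified model).
_tls = threading.local()


def _stack():
    if not hasattr(_tls, "listeners"):
        _tls.listeners = []
    return _tls.listeners


@contextlib.contextmanager
def listener_scope(listener):
    """Push `listener` for the duration of a `with` block, then pop it."""
    _stack().append(listener)
    try:
        yield listener
    finally:
        _stack().pop()


def current_listener():
    stack = _stack()
    return stack[-1] if stack else None


class RecordingListener:
    def __init__(self):
        self.inserted = []

    def notify_operation_inserted(self, op):
        self.inserted.append(op)


def create_op(name, *operands):
    """Hypothetical op constructor: builds the op, then notifies the top listener."""
    op = (name, operands)
    listener = current_listener()
    if listener is not None:
        listener.notify_operation_inserted(op)
    return op


recorder = RecordingListener()
with listener_scope(recorder):
    inside = create_op("arith.addi", "lhs", "rhs")  # notified
outside = create_op("arith.addi", "lhs", "rhs")  # no listener active
```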

@PragmaTwice PragmaTwice changed the title [MLIR][Python] Call notifyOperationInserted while constructing new op in rewrite patterns [MLIR][Python] Call notifyOperationInserted while constructing new op in rewrite patterns Oct 16, 2025
@PragmaTwice PragmaTwice marked this pull request as ready for review October 16, 2025 17:14
@llvmbot llvmbot added the mlir label Oct 16, 2025
@llvmbot
Member

llvmbot commented Oct 16, 2025

@llvm/pr-subscribers-mlir

Author: Twice (PragmaTwice)

Full diff: https://github.com/llvm/llvm-project/pull/163694.diff

6 Files Affected:

  • (modified) mlir/include/mlir-c/Rewrite.h (+10)
  • (modified) mlir/include/mlir/CAPI/Rewrite.h (+1)
  • (modified) mlir/lib/Bindings/Python/IRCore.cpp (+69-6)
  • (modified) mlir/lib/Bindings/Python/IRModule.h (+39-3)
  • (modified) mlir/lib/Bindings/Python/Rewrite.cpp (+16-1)
  • (modified) mlir/lib/CAPI/Transforms/Rewrite.cpp (+17)
diff --git a/mlir/include/mlir-c/Rewrite.h b/mlir/include/mlir-c/Rewrite.h
index 2db1d84cd1d89..f42aaff4d3c19 100644
--- a/mlir/include/mlir-c/Rewrite.h
+++ b/mlir/include/mlir-c/Rewrite.h
@@ -39,6 +39,7 @@ DEFINE_C_API_STRUCT(MlirGreedyRewriteDriverConfig, void);
 DEFINE_C_API_STRUCT(MlirRewritePatternSet, void);
 DEFINE_C_API_STRUCT(MlirPatternRewriter, void);
 DEFINE_C_API_STRUCT(MlirRewritePattern, const void);
+DEFINE_C_API_STRUCT(MlirRewriterBaseListener, void);
 
 //===----------------------------------------------------------------------===//
 /// RewriterBase API inherited from OpBuilder
@@ -48,6 +49,15 @@ DEFINE_C_API_STRUCT(MlirRewritePattern, const void);
 MLIR_CAPI_EXPORTED MlirContext
 mlirRewriterBaseGetContext(MlirRewriterBase rewriter);
 
+/// Get the listener of the rewriter.
+MLIR_CAPI_EXPORTED MlirRewriterBaseListener
+mlirRewriterBaseGetListener(MlirRewriterBase rewriter);
+
+/// Notify the listener that the specified operation was inserted.
+MLIR_CAPI_EXPORTED void mlirRewriterBaseListenerNotifyOperationInserted(
+    MlirRewriterBaseListener listener, MlirOperation op,
+    MlirOperation insertionPoint);
+
 //===----------------------------------------------------------------------===//
 /// Insertion points methods
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/CAPI/Rewrite.h b/mlir/include/mlir/CAPI/Rewrite.h
index 9c96d354d4fc9..357f2ae5a4418 100644
--- a/mlir/include/mlir/CAPI/Rewrite.h
+++ b/mlir/include/mlir/CAPI/Rewrite.h
@@ -26,6 +26,7 @@ DEFINE_C_API_PTR_METHODS(MlirRewritePatternSet, mlir::RewritePatternSet)
 DEFINE_C_API_PTR_METHODS(MlirFrozenRewritePatternSet,
                          mlir::FrozenRewritePatternSet)
 DEFINE_C_API_PTR_METHODS(MlirPatternRewriter, mlir::PatternRewriter)
+DEFINE_C_API_PTR_METHODS(MlirRewriterBaseListener, mlir::RewriterBase::Listener)
 
 #if MLIR_ENABLE_PDL_IN_PATTERNMATCH
 DEFINE_C_API_PTR_METHODS(MlirPDLPatternModule, mlir::PDLPatternModule)
diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
index 7b1710656243a..66d008445227b 100644
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -810,11 +810,11 @@ PyThreadContextEntry *PyThreadContextEntry::getTopOfStack() {
 }
 
 void PyThreadContextEntry::push(FrameKind frameKind, nb::object context,
-                                nb::object insertionPoint,
-                                nb::object location) {
+                                nb::object insertionPoint, nb::object location,
+                                nb::object listener) {
   auto &stack = getStack();
   stack.emplace_back(frameKind, std::move(context), std::move(insertionPoint),
-                     std::move(location));
+                     std::move(location), std::move(listener));
   // If the new stack has more than one entry and the context of the new top
   // entry matches the previous, copy the insertionPoint and location from the
   // previous entry if missing from the new top entry.
@@ -827,6 +827,8 @@ void PyThreadContextEntry::push(FrameKind frameKind, nb::object context,
         current.insertionPoint = prev.insertionPoint;
       if (!current.location)
         current.location = prev.location;
+      if (!current.listener)
+        current.listener = prev.listener;
     }
   }
 }
@@ -849,6 +851,12 @@ PyLocation *PyThreadContextEntry::getLocation() {
   return nb::cast<PyLocation *>(location);
 }
 
+PyRewriterBaseListener *PyThreadContextEntry::getListener() {
+  if (!listener)
+    return nullptr;
+  return nb::cast<PyRewriterBaseListener *>(listener);
+}
+
 PyMlirContext *PyThreadContextEntry::getDefaultContext() {
   auto *tos = getTopOfStack();
   return tos ? tos->getContext() : nullptr;
@@ -864,10 +872,16 @@ PyLocation *PyThreadContextEntry::getDefaultLocation() {
   return tos ? tos->getLocation() : nullptr;
 }
 
+PyRewriterBaseListener *PyThreadContextEntry::getDefaultListener() {
+  auto *tos = getTopOfStack();
+  return tos ? tos->getListener() : nullptr;
+}
+
 nb::object PyThreadContextEntry::pushContext(nb::object context) {
   push(FrameKind::Context, /*context=*/context,
        /*insertionPoint=*/nb::object(),
-       /*location=*/nb::object());
+       /*location=*/nb::object(),
+       /*listener=*/nb::object());
   return context;
 }
 
@@ -890,7 +904,8 @@ PyThreadContextEntry::pushInsertionPoint(nb::object insertionPointObj) {
   push(FrameKind::InsertionPoint,
        /*context=*/contextObj,
        /*insertionPoint=*/insertionPointObj,
-       /*location=*/nb::object());
+       /*location=*/nb::object(),
+       /*listener=*/nb::object());
   return insertionPointObj;
 }
 
@@ -910,7 +925,8 @@ nb::object PyThreadContextEntry::pushLocation(nb::object locationObj) {
   nb::object contextObj = location.getContext().getObject();
   push(FrameKind::Location, /*context=*/contextObj,
        /*insertionPoint=*/nb::object(),
-       /*location=*/locationObj);
+       /*location=*/locationObj,
+       /*listener=*/nb::object());
   return locationObj;
 }
 
@@ -924,6 +940,27 @@ void PyThreadContextEntry::popLocation(PyLocation &location) {
   stack.pop_back();
 }
 
+nb::object PyThreadContextEntry::pushListener(nb::object listenerObj) {
+  PyRewriterBaseListener &listener =
+      nb::cast<PyRewriterBaseListener &>(listenerObj);
+  nb::object contextObj = listener.getContext().getObject();
+  push(FrameKind::Location, /*context=*/contextObj,
+       /*insertionPoint=*/nb::object(),
+       /*location=*/nb::object(),
+       /*listener=*/listenerObj);
+  return listenerObj;
+}
+
+void PyThreadContextEntry::popListener(PyRewriterBaseListener &listener) {
+  auto &stack = getStack();
+  if (stack.empty())
+    throw std::runtime_error("Unbalanced Listener enter/exit");
+  auto &tos = stack.back();
+  if (tos.frameKind != FrameKind::Listener && tos.getListener() != &listener)
+    throw std::runtime_error("Unbalanced Listener enter/exit");
+  stack.pop_back();
+}
+
 //------------------------------------------------------------------------------
 // PyDiagnostic*
 //------------------------------------------------------------------------------
@@ -1417,6 +1454,10 @@ static void maybeInsertOperation(PyOperationRef &op,
     if (ip)
       ip->insert(*op.get());
   }
+  if (PyRewriterBaseListener *listener =
+          PyThreadContextEntry::getDefaultListener()) {
+    listener->notifyOperationInserted(*op.get());
+  }
 }
 
 nb::object PyOperation::create(std::string_view name,
@@ -2036,6 +2077,19 @@ PyOpView::PyOpView(const nb::object &operationObject)
     : operation(nb::cast<PyOperationBase &>(operationObject).getOperation()),
       operationObject(operation.getRef().getObject()) {}
 
+//------------------------------------------------------------------------------
+// PyRewriterBaseListener.
+//------------------------------------------------------------------------------
+
+nb::object PyRewriterBaseListener::contextEnter(nb::object listener) {
+  return PyThreadContextEntry::pushListener(std::move(listener));
+}
+
+void PyRewriterBaseListener::contextExit(nb::handle excType, nb::handle excVal,
+                                         nb::handle excTb) {
+  PyThreadContextEntry::popListener(*this);
+}
+
 //------------------------------------------------------------------------------
 // PyInsertionPoint.
 //------------------------------------------------------------------------------
@@ -3961,6 +4015,15 @@ void mlir::python::populateIRCore(nb::module_ &m) {
           },
           "Returns the list of Block predecessors.");
 
+  //----------------------------------------------------------------------------
+  // Mapping of PyRewriterBaseListener.
+  //----------------------------------------------------------------------------
+  nb::class_<PyRewriterBaseListener>(m, "RewriterBaseListener")
+      .def("__enter__", &PyRewriterBaseListener::contextEnter)
+      .def("__exit__", &PyRewriterBaseListener::contextExit,
+           nb::arg("exc_type").none(), nb::arg("exc_value").none(),
+           nb::arg("traceback").none());
+
   //----------------------------------------------------------------------------
   // Mapping of PyInsertionPoint.
   //----------------------------------------------------------------------------
diff --git a/mlir/lib/Bindings/Python/IRModule.h b/mlir/lib/Bindings/Python/IRModule.h
index e706be3b4d32a..9985af0a448a6 100644
--- a/mlir/lib/Bindings/Python/IRModule.h
+++ b/mlir/lib/Bindings/Python/IRModule.h
@@ -46,6 +46,7 @@ class PyOperationBase;
 class PyType;
 class PySymbolTable;
 class PyValue;
+class PyRewriterBaseListener;
 
 /// Template for a reference to a concrete type which captures a python
 /// reference to its underlying python object.
@@ -115,13 +116,15 @@ class PyThreadContextEntry {
     Context,
     InsertionPoint,
     Location,
+    Listener,
   };
 
   PyThreadContextEntry(FrameKind frameKind, nanobind::object context,
                        nanobind::object insertionPoint,
-                       nanobind::object location)
+                       nanobind::object location, nanobind::object listener)
       : context(std::move(context)), insertionPoint(std::move(insertionPoint)),
-        location(std::move(location)), frameKind(frameKind) {}
+        location(std::move(location)), listener(std::move(listener)),
+        frameKind(frameKind) {}
 
   /// Gets the top of stack context and return nullptr if not defined.
   static PyMlirContext *getDefaultContext();
@@ -132,9 +135,12 @@ class PyThreadContextEntry {
   /// Gets the top of stack location and returns nullptr if not defined.
   static PyLocation *getDefaultLocation();
 
+  static PyRewriterBaseListener *getDefaultListener();
+
   PyMlirContext *getContext();
   PyInsertionPoint *getInsertionPoint();
   PyLocation *getLocation();
+  PyRewriterBaseListener *getListener();
   FrameKind getFrameKind() { return frameKind; }
 
   /// Stack management.
@@ -145,13 +151,16 @@ class PyThreadContextEntry {
   static void popInsertionPoint(PyInsertionPoint &insertionPoint);
   static nanobind::object pushLocation(nanobind::object location);
   static void popLocation(PyLocation &location);
+  static nanobind::object pushListener(nanobind::object listener);
+  static void popListener(PyRewriterBaseListener &listener);
 
   /// Gets the thread local stack.
   static std::vector<PyThreadContextEntry> &getStack();
 
 private:
   static void push(FrameKind frameKind, nanobind::object context,
-                   nanobind::object insertionPoint, nanobind::object location);
+                   nanobind::object insertionPoint, nanobind::object location,
+                   nanobind::object listener);
 
   /// An object reference to the PyContext.
   nanobind::object context;
@@ -159,6 +168,8 @@ class PyThreadContextEntry {
   nanobind::object insertionPoint;
   /// An object reference to the current location.
   nanobind::object location;
+  /// An object reference to the current listener.
+  nanobind::object listener;
   // The kind of push that was performed.
   FrameKind frameKind;
 };
@@ -830,6 +841,31 @@ class PyBlock {
   MlirBlock block;
 };
 
+/// Wrapper around a MlirRewriterBaseListener.
+class PyRewriterBaseListener {
+public:
+  PyRewriterBaseListener(MlirRewriterBaseListener listener,
+                         PyMlirContextRef ctx)
+      : listener(listener), ctx(std::move(ctx)) {}
+
+  MlirRewriterBaseListener get() { return listener; }
+
+  void notifyOperationInserted(PyOperationBase &op) {
+    mlirRewriterBaseListenerNotifyOperationInserted(get(), op.getOperation(),
+                                                    MlirOperation{nullptr});
+  }
+
+  PyMlirContextRef getContext() { return ctx; }
+
+  static nanobind::object contextEnter(nanobind::object listener);
+  void contextExit(nanobind::handle excType, nanobind::handle excVal,
+                   nanobind::handle excTb);
+
+private:
+  MlirRewriterBaseListener listener;
+  PyMlirContextRef ctx;
+};
+
 /// An insertion point maintains a pointer to a Block and a reference operation.
 /// Calls to insert() will insert a new operation before the
 /// reference operation. If the reference operation is null, then appends to
diff --git a/mlir/lib/Bindings/Python/Rewrite.cpp b/mlir/lib/Bindings/Python/Rewrite.cpp
index 5ddb3fbbb1317..5512fb2377d60 100644
--- a/mlir/lib/Bindings/Python/Rewrite.cpp
+++ b/mlir/lib/Bindings/Python/Rewrite.cpp
@@ -18,6 +18,7 @@
 // clang-format on
 #include "mlir/Config/mlir-config.h"
 #include "nanobind/nanobind.h"
+#include "llvm/ADT/ScopeExit.h"
 
 namespace nb = nanobind;
 using namespace mlir;
@@ -45,6 +46,10 @@ class PyPatternRewriter {
     return PyInsertionPoint(PyOperation::forOperation(ctx, op));
   }
 
+  PyRewriterBaseListener getListener() {
+    return PyRewriterBaseListener(mlirRewriterBaseGetListener(base), ctx);
+  }
+
   void replaceOp(MlirOperation op, MlirOperation newOp) {
     mlirRewriterBaseReplaceOpWithOperation(base, op, newOp);
   }
@@ -202,7 +207,15 @@ class PyRewritePatternSet {
           PyMlirContext::forContext(mlirOperationGetContext(op));
       nb::object opView = PyOperation::forOperation(ctx, op)->createOpView();
 
-      nb::object res = f(opView, PyPatternRewriter(rewriter));
+      PyPatternRewriter pyRewriter(rewriter);
+      nb::object listener = nb::cast(pyRewriter.getListener());
+
+      listener.attr("__enter__")();
+      auto exit = llvm::make_scope_exit([listener] {
+        listener.attr("__exit__")(nb::none(), nb::none(), nb::none());
+      });
+      nb::object res = f(opView, pyRewriter);
+
       return logicalResultFromObject(res);
     };
     MlirRewritePattern pattern = mlirOpRewritePattenCreate(
@@ -234,6 +247,8 @@ void mlir::python::populateRewriteSubmodule(nb::module_ &m) {
       class_<PyPatternRewriter>(m, "PatternRewriter")
           .def_prop_ro("ip", &PyPatternRewriter::getInsertionPoint,
                        "The current insertion point of the PatternRewriter.")
+          .def_prop_ro("listener", &PyPatternRewriter::getListener,
+                       "The rewrite listener of the PatternRewriter.")
           .def(
               "replace_op",
               [](PyPatternRewriter &self, MlirOperation op,
diff --git a/mlir/lib/CAPI/Transforms/Rewrite.cpp b/mlir/lib/CAPI/Transforms/Rewrite.cpp
index 46c329d8433b4..41b48bdfe5d6b 100644
--- a/mlir/lib/CAPI/Transforms/Rewrite.cpp
+++ b/mlir/lib/CAPI/Transforms/Rewrite.cpp
@@ -29,6 +29,23 @@ MlirContext mlirRewriterBaseGetContext(MlirRewriterBase rewriter) {
   return wrap(unwrap(rewriter)->getContext());
 }
 
+MlirRewriterBaseListener
+mlirRewriterBaseGetListener(MlirRewriterBase rewriter) {
+  return wrap(
+      dyn_cast<RewriterBase::Listener>(unwrap(rewriter)->getListener()));
+}
+
+void mlirRewriterBaseListenerNotifyOperationInserted(
+    MlirRewriterBaseListener listener, MlirOperation op,
+    MlirOperation insertionPoint) {
+  OpBuilder::InsertPoint ip;
+  if (!mlirOperationIsNull(insertionPoint)) {
+    ip = OpBuilder::InsertPoint(unwrap(insertionPoint)->getBlock(),
+                                Block::iterator(unwrap(insertionPoint)));
+  }
+  return unwrap(listener)->notifyOperationInserted(unwrap(op), ip);
+}
+
 //===----------------------------------------------------------------------===//
 /// Insertion points methods
 //===----------------------------------------------------------------------===//
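The frame-inheritance rule in PyThreadContextEntry::push can be modeled in a few lines of Python. This is a hypothetical sketch, not the bindings' code: Frame, push and pop_listener are illustrative names. A new frame inherits any unset field (insertion point, location, listener) from the previous frame when both share the same context. The sketch gives listener frames their own "Listener" kind and only pops when both the kind and the listener match; for comparison, pushListener in the hunk above tags its frame as FrameKind::Location, and popListener joins its two checks with &&.

```python
from dataclasses import dataclass
from typing import Any, List, Optional


@dataclass
class Frame:
    kind: str
    context: Any
    insertion_point: Optional[Any] = None
    location: Optional[Any] = None
    listener: Optional[Any] = None


def push(stack: List[Frame], frame: Frame) -> None:
    """Push `frame`, copying unset fields from the previous frame if it has the
    same context (identity check, like nb::object::is in the C++ code)."""
    if stack and stack[-1].context is frame.context:
        prev = stack[-1]
        if frame.insertion_point is None:
            frame.insertion_point = prev.insertion_point
        if frame.location is None:
            frame.location = prev.location
        if frame.listener is None:
            frame.listener = prev.listener
    stack.append(frame)


def pop_listener(stack: List[Frame], listener: Any) -> None:
    """Pop a listener frame, rejecting unbalanced enter/exit."""
    if not stack or stack[-1].kind != "Listener" or stack[-1].listener is not listener:
        raise RuntimeError("Unbalanced Listener enter/exit")
    stack.pop()


ctx, listener = object(), object()
stack: List[Frame] = []
push(stack, Frame("Context", ctx))
push(stack, Frame("InsertionPoint", ctx, insertion_point="ip0"))
push(stack, Frame("Listener", ctx, listener=listener))
top_ip = stack[-1].insertion_point  # inherited from the InsertionPoint frame
pop_listener(stack, listener)
```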

@PragmaTwice
Member Author

PragmaTwice commented Oct 16, 2025

I haven't come up with a good way to write test cases for this change yet, but otherwise the change is ready for review.

Comment on lines +211 to +216
nb::object listener = nb::cast(pyRewriter.getListener());

listener.attr("__enter__")();
auto exit = llvm::make_scope_exit([listener] {
listener.attr("__exit__")(nb::none(), nb::none(), nb::none());
});
Contributor

@makslevental makslevental Oct 16, 2025

i think both the addition of PyRewriterBaseListener and the idea of somehow calling notify on listeners are great ideas. but i think using it like this (in addition to adding it to the default threadcontextstack) is not a good idea. 2 possibilities:

  1. having a separate threadcontextstack just for listeners (but also - is this how listeners are actually composed? isn't there just a parent->child relationship? i don't remember right this second)
  2. passing in the listener to add (here) explicitly

also btw, i think ADT isn't actually safe to use like this? i need to double check. (looks like IRAttributes already uses ScopeExit 🤷)

Member Author

@PragmaTwice PragmaTwice Oct 17, 2025

having a separate threadcontextstack just for listeners

yup, it's possible, and the code can be simpler that way. we wouldn't even need to make the listener a Python object.

is this how listeners are actually composed? aren't there just like parent->child relationship

we could just have a thread-local current listener instead of a stack of listeners, but to allow users to do weird nested things like this:

def rewrite(op, rewriter):
    def rewrite2(op, rewriter): ...
    patterns = RewritePatternSet()
    patterns.add(.., rewrite2)
    ...

patterns = RewritePatternSet()
patterns.add(.., rewrite)

a stack is more general (though i'm not sure it's worth it, since this seems like an anti-pattern...)
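The nesting case above can be illustrated with a small hypothetical model comparing a single "current listener" slot with a stack. SingleSlot, Stack and run_nested are illustrative names, and strings stand in for listeners. After the inner rewrite finishes, only the stack still knows the outer rewriter's listener:

```python
class SingleSlot:
    """A single thread-local 'current listener' slot (hypothetical)."""

    def __init__(self):
        self.current = None

    def enter(self, listener):
        self.current = listener

    def exit(self):
        self.current = None  # the enclosing listener is lost


class Stack:
    """A stack of listeners: exiting restores the enclosing listener."""

    def __init__(self):
        self.items = []

    @property
    def current(self):
        return self.items[-1] if self.items else None

    def enter(self, listener):
        self.items.append(listener)

    def exit(self):
        self.items.pop()


def run_nested(state):
    """Outer callback applies an inner pattern set, then would create another op."""
    state.enter("outer")
    state.enter("inner")  # inner apply inside the outer callback
    state.exit()
    seen_after_inner = state.current  # which listener would the next op notify?
    state.exit()
    return seen_after_inner
```

With a single slot, an op created in the outer callback after the inner apply would notify nobody; with the stack it notifies the outer listener.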

passing in the listener to add (here) explicitly

this can be a little hard since the listener is effectively a field of PatternRewriter, and outside the callback passed to add we can't easily access the rewriter.

Contributor

@makslevental makslevental Oct 17, 2025

I guess the "core" way to do this is ForwardingListener, i.e. the listener "chains" have to be constructed when the leaf listener is instantiated (the leaf then forwards up the chain, or something like that). but that's too complicated for right now, before anyone actually needs it (asks for it). let's leave stacks/chains of listeners out for now.
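The chaining idea can be sketched in plain Python: the chain is fixed when the leaf listener is constructed, and each notification is handled locally and then forwarded to the parent. The class names are hypothetical, and the model only mirrors the general idea behind RewriterBase::ForwardingListener, not its exact C++ interface:

```python
class RecordingListener:
    """Root of the chain (e.g. the driver's listener); just logs notifications."""

    def __init__(self, name, log):
        self.name = name
        self.log = log

    def notify_operation_inserted(self, op):
        self.log.append((self.name, op))


class ForwardingListener:
    """Forwards every notification to `parent` (simplified, hypothetical)."""

    def __init__(self, parent):
        self.parent = parent

    def notify_operation_inserted(self, op):
        if self.parent is not None:
            self.parent.notify_operation_inserted(op)


class LeafListener(ForwardingListener):
    """Handles the notification itself, then forwards it up the chain."""

    def __init__(self, name, log, parent=None):
        super().__init__(parent)
        self.name = name
        self.log = log

    def notify_operation_inserted(self, op):
        self.log.append((self.name, op))
        super().notify_operation_inserted(op)


log = []
root = RecordingListener("driver", log)
leaf = LeafListener("pattern", log, parent=root)  # chain built at construction
leaf.notify_operation_inserted("arith.addi")
```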

this can be a little hard since listener is like a field of PatternRewriter

can't you play the same trick we play everywhere already: just put it into *userData? basically add an add overload that takes a listener and sticks it into *userData along with the callable?
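A hypothetical Python model of the userData trick: the pattern set stores a listener next to each callable at add() time, the way the C++ side could bundle both into one userData payload, and hands it to the callback. PatternSet, its add/apply signatures and LoggingListener are illustrative assumptions, not the bindings' API:

```python
class PatternSet:
    """Hypothetical pattern set that keeps (root, fn, listener, benefit) per pattern."""

    def __init__(self):
        self.patterns = []

    def add(self, root, fn, listener=None, benefit=1):
        self.patterns.append((root, fn, listener, benefit))

    def apply(self, op_name, rewriter):
        for root, fn, listener, _benefit in self.patterns:
            if root == op_name:
                # The callback receives the listener captured at add() time.
                return fn(op_name, rewriter, listener)
        return None


events = []


class LoggingListener:
    def notify_operation_inserted(self, op):
        events.append(op)


def rewrite(op, rewriter, listener):
    listener.notify_operation_inserted("arith.muli")
    return True


patterns = PatternSet()
patterns.add("arith.addi", rewrite, listener=LoggingListener())
result = patterns.apply("arith.addi", rewriter=None)
```

This only works when the listener is already known at add() time, which is the limitation raised in the replies below.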

Contributor

basically i guess i'm saying the listener doesn't have to be on the PatternRewriter just because it's like that in cpp...

Member Author

@PragmaTwice PragmaTwice Oct 17, 2025

I'll refactor this change soon.

basically add an add overload that takes a listener and sticks it into *userData along with the callable?

yeah, I get your point, but we can't get the listener before we have the rewriter, and we only get the rewriter once we're inside the callback (example below):

def f(op, rewriter):
  # now we can get the rewriter!
  rewriter.listener # can get the listener like this

patterns = RewritePatternSet()
patterns.add(arith.SomeOp, f) # we cannot get it here
frozen = patterns.freeze()

apply_patterns_and_fold_greedily(op, frozen) # I think the rewriter is constructed inside this function

the catch here is that we need to get the listener inside every Operation.create(..) and every wrapper like arith.addi(..), so we have to maintain some global state for retrieving the current listener (shown below):

def f(op, rewriter):
  # we save the rewriter.listener as the current listener (push)
  new_op = arith.addi(lhs, rhs) # we need to retrieve the current listener inside this function and call notifyOperationInserted
  ... # do something with new_op
  # in the end we undo the save (pop)

Contributor

gimme a couple minutes to try your patch...

Contributor

here's a sketch of what i'm talking about (on top of your current commit):

diff --git a/mlir/lib/Bindings/Python/Rewrite.cpp b/mlir/lib/Bindings/Python/Rewrite.cpp
index 5512fb2377d6..e41a99be1473 100644
--- a/mlir/lib/Bindings/Python/Rewrite.cpp
+++ b/mlir/lib/Bindings/Python/Rewrite.cpp
@@ -189,37 +189,44 @@ public:
       mlirRewritePatternSetDestroy(set);
   }
 
+  struct UserData {
+    const nb::callable &matchAndRewriteCb;
+    nb::object listener;
+  };
+
   void add(MlirStringRef rootName, unsigned benefit,
-           const nb::callable &matchAndRewrite) {
+           const nb::callable &matchAndRewriteCb, nb::object listener) {
     MlirRewritePatternCallbacks callbacks;
     callbacks.construct = [](void *userData) {
-      nb::handle(static_cast<PyObject *>(userData)).inc_ref();
+      UserData *userData_ = static_cast<UserData *>(userData);
+      userData_->matchAndRewriteCb.inc_ref();
+      userData_->listener.inc_ref();
     };
     callbacks.destruct = [](void *userData) {
-      nb::handle(static_cast<PyObject *>(userData)).dec_ref();
+      UserData *userData_ = static_cast<UserData *>(userData);
+      userData_->matchAndRewriteCb.dec_ref();
+      userData_->listener.dec_ref();
     };
     callbacks.matchAndRewrite = [](MlirRewritePattern, MlirOperation op,
                                    MlirPatternRewriter rewriter,
                                    void *userData) -> MlirLogicalResult {
-      nb::handle f(static_cast<PyObject *>(userData));
+      UserData *userData_ = static_cast<UserData *>(userData);
+      nb::handle f(userData_->matchAndRewriteCb);
 
       PyMlirContextRef ctx =
           PyMlirContext::forContext(mlirOperationGetContext(op));
       nb::object opView = PyOperation::forOperation(ctx, op)->createOpView();
 
       PyPatternRewriter pyRewriter(rewriter);
-      nb::object listener = nb::cast(pyRewriter.getListener());
-
-      listener.attr("__enter__")();
-      auto exit = llvm::make_scope_exit([listener] {
-        listener.attr("__exit__")(nb::none(), nb::none(), nb::none());
-      });
-      nb::object res = f(opView, pyRewriter);
+      nb::object listener = userData_->listener;
+      nb::object res = f(opView, pyRewriter, listener);
 
       return logicalResultFromObject(res);
     };
+
+    UserData *userData_ = new UserData{matchAndRewriteCb, listener};
     MlirRewritePattern pattern = mlirOpRewritePattenCreate(
-        rootName, benefit, ctx, callbacks, matchAndRewrite.ptr(),
+        rootName, benefit, ctx, callbacks, static_cast<void *>(userData_),
         /* nGeneratedNames */ 0,
         /* generatedNames */ nullptr);
     mlirRewritePatternSetAdd(set, pattern);
@@ -291,13 +298,13 @@ void mlir::python::populateRewriteSubmodule(nb::module_ &m) {
       .def(
           "add",
           [](PyRewritePatternSet &self, nb::handle root, const nb::callable &fn,
-             unsigned benefit) {
+             const nb::object &listener, unsigned benefit) {
             std::string opName =
                 nb::cast<std::string>(root.attr("OPERATION_NAME"));
             self.add(mlirStringRefCreate(opName.data(), opName.size()), benefit,
-                     fn);
+                     fn, listener);
           },
-          "root"_a, "fn"_a, "benefit"_a = 1,
+          "root"_a, "fn"_a, "listener"_a, "benefit"_a = 1,
           "Add a new rewrite pattern on the given root operation with the "
           "callable as the matching and rewriting function and the given "
           "benefit.")

Member Author

Thank you for your patch! I understand the code, but since we can already get the listener via
nb::object listener = nb::cast(pyRewriter.getListener())
without passing additional user data, is there any difference?

Member Author

@PragmaTwice PragmaTwice Oct 17, 2025

ahhh, I think I see where our understandings differ.

Let me explain step by step why we need global state (or a stack). First, we could drop the stack/state entirely if we wrote code like this:

def match_and_rewrite(op, rewriter):
  new_op = arith.addi(op.lhs, op.rhs, listener=rewriter.listener)
  # or just: new_op = arith.addi(op.lhs, op.rhs, rewriter=rewriter)
  ...

Here we pass the listener (or the rewriter) into the op constructor, so we don't need any global state; the constructor can then use this listener to call the notify* methods. But from the start I wanted to avoid such passing, i.e. to support code like this:

def match_and_rewrite(op, rewriter):
  new_op = arith.addi(op.lhs, op.rhs) # no listener/rewriter passed, so we need to retrieve it from some global state!
  ...

Then, obviously, we need some global state that tells the op constructor which listener is current, so that it can call the notify methods.

If we want to avoid a global state/stack, maybe we can just pass the listener to op constructors? This may require some changes to tblgen.
