Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions mlir/include/mlir-c/Rewrite.h
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ DEFINE_C_API_STRUCT(MlirGreedyRewriteDriverConfig, void);
DEFINE_C_API_STRUCT(MlirRewritePatternSet, void);
DEFINE_C_API_STRUCT(MlirPatternRewriter, void);
DEFINE_C_API_STRUCT(MlirRewritePattern, const void);
DEFINE_C_API_STRUCT(MlirRewriterBaseListener, void);

//===----------------------------------------------------------------------===//
/// RewriterBase API inherited from OpBuilder
Expand All @@ -48,6 +49,15 @@ DEFINE_C_API_STRUCT(MlirRewritePattern, const void);
MLIR_CAPI_EXPORTED MlirContext
mlirRewriterBaseGetContext(MlirRewriterBase rewriter);

/// Get the listener of the rewriter.
MLIR_CAPI_EXPORTED MlirRewriterBaseListener
mlirRewriterBaseGetListener(MlirRewriterBase rewriter);

/// Notify the listener that the specified operation was inserted.
MLIR_CAPI_EXPORTED void mlirRewriterBaseListenerNotifyOperationInserted(
MlirRewriterBaseListener listener, MlirOperation op,
MlirOperation insertionPoint);

//===----------------------------------------------------------------------===//
/// Insertion points methods
//===----------------------------------------------------------------------===//
Expand Down
1 change: 1 addition & 0 deletions mlir/include/mlir/CAPI/Rewrite.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ DEFINE_C_API_PTR_METHODS(MlirRewritePatternSet, mlir::RewritePatternSet)
DEFINE_C_API_PTR_METHODS(MlirFrozenRewritePatternSet,
mlir::FrozenRewritePatternSet)
DEFINE_C_API_PTR_METHODS(MlirPatternRewriter, mlir::PatternRewriter)
DEFINE_C_API_PTR_METHODS(MlirRewriterBaseListener, mlir::RewriterBase::Listener)

#if MLIR_ENABLE_PDL_IN_PATTERNMATCH
DEFINE_C_API_PTR_METHODS(MlirPDLPatternModule, mlir::PDLPatternModule)
Expand Down
75 changes: 69 additions & 6 deletions mlir/lib/Bindings/Python/IRCore.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -810,11 +810,11 @@ PyThreadContextEntry *PyThreadContextEntry::getTopOfStack() {
}

void PyThreadContextEntry::push(FrameKind frameKind, nb::object context,
nb::object insertionPoint,
nb::object location) {
nb::object insertionPoint, nb::object location,
nb::object listener) {
auto &stack = getStack();
stack.emplace_back(frameKind, std::move(context), std::move(insertionPoint),
std::move(location));
std::move(location), std::move(listener));
// If the new stack has more than one entry and the context of the new top
// entry matches the previous, copy the insertionPoint and location from the
// previous entry if missing from the new top entry.
Expand All @@ -827,6 +827,8 @@ void PyThreadContextEntry::push(FrameKind frameKind, nb::object context,
current.insertionPoint = prev.insertionPoint;
if (!current.location)
current.location = prev.location;
if (!current.listener)
current.listener = prev.listener;
}
}
}
Expand All @@ -849,6 +851,12 @@ PyLocation *PyThreadContextEntry::getLocation() {
return nb::cast<PyLocation *>(location);
}

PyRewriterBaseListener *PyThreadContextEntry::getListener() {
if (!listener)
return nullptr;
return nb::cast<PyRewriterBaseListener *>(listener);
}

PyMlirContext *PyThreadContextEntry::getDefaultContext() {
auto *tos = getTopOfStack();
return tos ? tos->getContext() : nullptr;
Expand All @@ -864,10 +872,16 @@ PyLocation *PyThreadContextEntry::getDefaultLocation() {
return tos ? tos->getLocation() : nullptr;
}

PyRewriterBaseListener *PyThreadContextEntry::getDefaultListener() {
auto *tos = getTopOfStack();
return tos ? tos->getListener() : nullptr;
}

nb::object PyThreadContextEntry::pushContext(nb::object context) {
push(FrameKind::Context, /*context=*/context,
/*insertionPoint=*/nb::object(),
/*location=*/nb::object());
/*location=*/nb::object(),
/*listener=*/nb::object());
return context;
}

Expand All @@ -890,7 +904,8 @@ PyThreadContextEntry::pushInsertionPoint(nb::object insertionPointObj) {
push(FrameKind::InsertionPoint,
/*context=*/contextObj,
/*insertionPoint=*/insertionPointObj,
/*location=*/nb::object());
/*location=*/nb::object(),
/*listener=*/nb::object());
return insertionPointObj;
}

Expand All @@ -910,7 +925,8 @@ nb::object PyThreadContextEntry::pushLocation(nb::object locationObj) {
nb::object contextObj = location.getContext().getObject();
push(FrameKind::Location, /*context=*/contextObj,
/*insertionPoint=*/nb::object(),
/*location=*/locationObj);
/*location=*/locationObj,
/*listener=*/nb::object());
return locationObj;
}

Expand All @@ -924,6 +940,27 @@ void PyThreadContextEntry::popLocation(PyLocation &location) {
stack.pop_back();
}

nb::object PyThreadContextEntry::pushListener(nb::object listenerObj) {
PyRewriterBaseListener &listener =
nb::cast<PyRewriterBaseListener &>(listenerObj);
nb::object contextObj = listener.getContext().getObject();
push(FrameKind::Location, /*context=*/contextObj,
/*insertionPoint=*/nb::object(),
/*location=*/nb::object(),
/*listener=*/listenerObj);
return listenerObj;
}

void PyThreadContextEntry::popListener(PyRewriterBaseListener &listener) {
auto &stack = getStack();
if (stack.empty())
throw std::runtime_error("Unbalanced Listener enter/exit");
auto &tos = stack.back();
if (tos.frameKind != FrameKind::Listener && tos.getListener() != &listener)
throw std::runtime_error("Unbalanced Listener enter/exit");
stack.pop_back();
}

//------------------------------------------------------------------------------
// PyDiagnostic*
//------------------------------------------------------------------------------
Expand Down Expand Up @@ -1417,6 +1454,10 @@ static void maybeInsertOperation(PyOperationRef &op,
if (ip)
ip->insert(*op.get());
}
if (PyRewriterBaseListener *listener =
PyThreadContextEntry::getDefaultListener()) {
listener->notifyOperationInserted(*op.get());
}
}

nb::object PyOperation::create(std::string_view name,
Expand Down Expand Up @@ -2036,6 +2077,19 @@ PyOpView::PyOpView(const nb::object &operationObject)
: operation(nb::cast<PyOperationBase &>(operationObject).getOperation()),
operationObject(operation.getRef().getObject()) {}

//------------------------------------------------------------------------------
// PyRewriterBaseListener.
//------------------------------------------------------------------------------

nb::object PyRewriterBaseListener::contextEnter(nb::object listener) {
return PyThreadContextEntry::pushListener(std::move(listener));
}

void PyRewriterBaseListener::contextExit(nb::handle excType, nb::handle excVal,
nb::handle excTb) {
PyThreadContextEntry::popListener(*this);
}

//------------------------------------------------------------------------------
// PyInsertionPoint.
//------------------------------------------------------------------------------
Expand Down Expand Up @@ -3961,6 +4015,15 @@ void mlir::python::populateIRCore(nb::module_ &m) {
},
"Returns the list of Block predecessors.");

//----------------------------------------------------------------------------
// Mapping of PyRewriterBaseListener.
//----------------------------------------------------------------------------
nb::class_<PyRewriterBaseListener>(m, "RewriterBaseListener")
.def("__enter__", &PyRewriterBaseListener::contextEnter)
.def("__exit__", &PyRewriterBaseListener::contextExit,
nb::arg("exc_type").none(), nb::arg("exc_value").none(),
nb::arg("traceback").none());

//----------------------------------------------------------------------------
// Mapping of PyInsertionPoint.
//----------------------------------------------------------------------------
Expand Down
42 changes: 39 additions & 3 deletions mlir/lib/Bindings/Python/IRModule.h
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ class PyOperationBase;
class PyType;
class PySymbolTable;
class PyValue;
class PyRewriterBaseListener;

/// Template for a reference to a concrete type which captures a python
/// reference to its underlying python object.
Expand Down Expand Up @@ -115,13 +116,15 @@ class PyThreadContextEntry {
Context,
InsertionPoint,
Location,
Listener,
};

PyThreadContextEntry(FrameKind frameKind, nanobind::object context,
nanobind::object insertionPoint,
nanobind::object location)
nanobind::object location, nanobind::object listener)
: context(std::move(context)), insertionPoint(std::move(insertionPoint)),
location(std::move(location)), frameKind(frameKind) {}
location(std::move(location)), listener(std::move(listener)),
frameKind(frameKind) {}

/// Gets the top of stack context and return nullptr if not defined.
static PyMlirContext *getDefaultContext();
Expand All @@ -132,9 +135,12 @@ class PyThreadContextEntry {
/// Gets the top of stack location and returns nullptr if not defined.
static PyLocation *getDefaultLocation();

static PyRewriterBaseListener *getDefaultListener();

PyMlirContext *getContext();
PyInsertionPoint *getInsertionPoint();
PyLocation *getLocation();
PyRewriterBaseListener *getListener();
FrameKind getFrameKind() { return frameKind; }

/// Stack management.
Expand All @@ -145,20 +151,25 @@ class PyThreadContextEntry {
static void popInsertionPoint(PyInsertionPoint &insertionPoint);
static nanobind::object pushLocation(nanobind::object location);
static void popLocation(PyLocation &location);
static nanobind::object pushListener(nanobind::object listener);
static void popListener(PyRewriterBaseListener &listener);

/// Gets the thread local stack.
static std::vector<PyThreadContextEntry> &getStack();

private:
static void push(FrameKind frameKind, nanobind::object context,
nanobind::object insertionPoint, nanobind::object location);
nanobind::object insertionPoint, nanobind::object location,
nanobind::object listener);

/// An object reference to the PyContext.
nanobind::object context;
/// An object reference to the current insertion point.
nanobind::object insertionPoint;
/// An object reference to the current location.
nanobind::object location;
/// An object reference to the current listener.
nanobind::object listener;
// The kind of push that was performed.
FrameKind frameKind;
};
Expand Down Expand Up @@ -830,6 +841,31 @@ class PyBlock {
MlirBlock block;
};

/// Wrapper around a MlirRewriterBaseListener.
class PyRewriterBaseListener {
public:
PyRewriterBaseListener(MlirRewriterBaseListener listener,
PyMlirContextRef ctx)
: listener(listener), ctx(std::move(ctx)) {}

MlirRewriterBaseListener get() { return listener; }

void notifyOperationInserted(PyOperationBase &op) {
mlirRewriterBaseListenerNotifyOperationInserted(get(), op.getOperation(),
MlirOperation{nullptr});
}

PyMlirContextRef getContext() { return ctx; }

static nanobind::object contextEnter(nanobind::object listener);
void contextExit(nanobind::handle excType, nanobind::handle excVal,
nanobind::handle excTb);

private:
MlirRewriterBaseListener listener;
PyMlirContextRef ctx;
};

/// An insertion point maintains a pointer to a Block and a reference operation.
/// Calls to insert() will insert a new operation before the
/// reference operation. If the reference operation is null, then appends to
Expand Down
17 changes: 16 additions & 1 deletion mlir/lib/Bindings/Python/Rewrite.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
// clang-format on
#include "mlir/Config/mlir-config.h"
#include "nanobind/nanobind.h"
#include "llvm/ADT/ScopeExit.h"

namespace nb = nanobind;
using namespace mlir;
Expand Down Expand Up @@ -45,6 +46,10 @@ class PyPatternRewriter {
return PyInsertionPoint(PyOperation::forOperation(ctx, op));
}

PyRewriterBaseListener getListener() {
return PyRewriterBaseListener(mlirRewriterBaseGetListener(base), ctx);
}

void replaceOp(MlirOperation op, MlirOperation newOp) {
mlirRewriterBaseReplaceOpWithOperation(base, op, newOp);
}
Expand Down Expand Up @@ -202,7 +207,15 @@ class PyRewritePatternSet {
PyMlirContext::forContext(mlirOperationGetContext(op));
nb::object opView = PyOperation::forOperation(ctx, op)->createOpView();

nb::object res = f(opView, PyPatternRewriter(rewriter));
PyPatternRewriter pyRewriter(rewriter);
nb::object listener = nb::cast(pyRewriter.getListener());

listener.attr("__enter__")();
auto exit = llvm::make_scope_exit([listener] {
listener.attr("__exit__")(nb::none(), nb::none(), nb::none());
});
Comment on lines +211 to +216
Copy link
Contributor

@makslevental makslevental Oct 16, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

i think both the addition of PyRewriterBaseListener and the idea of somehow calling notify on listeners is a great idea. i think using it like this (in addition to adding it to the default threadcontextstack) is not a good idea. 2 possibilities:

  1. having a separate threadcontextstack just for listeners (but also - is this how listeners are actually composed? aren't there just like parent->child relationship? i don't remember right this second)
  2. passing in the listener to add (here) explicitly

also btw ADT isn't actually safe to use like this i think? i need to double check looks like IRAttributes already uses ScopeExit 🤷

Copy link
Member Author

@PragmaTwice PragmaTwice Oct 17, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

having a separate threadcontextstack just for listeners

yup it's possible and the code can be simpler by doing this. and even we don't need to make listener a python object.

is this how listeners are actually composed? aren't there just like parent->child relationship

we can just have a threadlocal current listener instead of a stack of listeners, but to allow users to do weird nested things like this:

def rewrite(op, rewriter):
    def rewrite2(op, rewriter): ...
    patterns = RewritePatternSet()
    patterns.add(.., rewrite2)
    ...

patterns = RewritePatternSet()
patterns.add(.., rewrite)

.. a stack can be more general (but i'm also not sure since this seems an anti-pattern..)

passing in the listener to add (here) explicitly

this can be a little hard since listener is like a field of PatternRewriter, and outside the callback of add we cannot access the rewriter easily.

Copy link
Contributor

@makslevental makslevental Oct 17, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I guess the "core" way to do this is ForwardingListener ie the listener "chains" have to actually be constructed at instantiation of the leaf listener (which then forwards up the chain or something like that). but that's too complicated for right now before anyone actually needs it (asks for it). let's leave off this stacks/chains of listeners for now.

this can be a little hard since listener is like a field of PatternRewriter

can't you play the same trick we play everywhere already: just put it into *userData? basically add an add overload that takes a listener and sticks it into *userData along with the callable?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

basically i guess i'm saying the listener doesn't have to be on the PatternRewriter just because it's like that in cpp...

Copy link
Member Author

@PragmaTwice PragmaTwice Oct 17, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'll refactor this change soon.

basically add an add overload that takes a listener and sticks it into *userData along with the callable?

yeah I can get your point, but we cannot get the listener before we get the rewriter, and only after we enter the callback we can finally get the rewriter. (shown as an example below)

def f(op, rewriter):
  # now we can get the rewriter!
  rewriter.listener # can get the listener like this

patterns = RewritePatternSet()
patterns.add(arith.SomeOp, f) # we cannot get it here
frozen = patterns.freeze()

apply_and_fold_greedily(op, frozen) # I think the rewriter is constructed inside this method

the trick here is that, we need to get the listener inside every Operation.create(..) and wrappers like arith.addi(..), so we should maintain a global state for retrieving the current listener. (shown below)

def f(op, rewriter):
  # we save the rewriter.listener as the current listener (push)
  new_op = arith.addi(lhs, rhs) # we need to retrieve the current listener inside this function and call notifyOperationInserted
  ... # do something with new_op
  # in the end we undo the save (pop)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

gimme a couple minutes to try your patch...

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

here's a sketch of what i'm talking about (on top of your current commit):

diff --git a/mlir/lib/Bindings/Python/Rewrite.cpp b/mlir/lib/Bindings/Python/Rewrite.cpp
index 5512fb2377d6..e41a99be1473 100644
--- a/mlir/lib/Bindings/Python/Rewrite.cpp
+++ b/mlir/lib/Bindings/Python/Rewrite.cpp
@@ -189,37 +189,44 @@ public:
       mlirRewritePatternSetDestroy(set);
   }
 
+  struct UserData {
+    const nb::callable &matchAndRewriteCb;
+    nb::object listener;
+  };
+
   void add(MlirStringRef rootName, unsigned benefit,
-           const nb::callable &matchAndRewrite) {
+           const nb::callable &matchAndRewriteCb, nb::object listener) {
     MlirRewritePatternCallbacks callbacks;
     callbacks.construct = [](void *userData) {
-      nb::handle(static_cast<PyObject *>(userData)).inc_ref();
+      UserData *userData_ = static_cast<UserData *>(userData);
+      userData_->matchAndRewriteCb.inc_ref();
+      userData_->listener.inc_ref();
     };
     callbacks.destruct = [](void *userData) {
-      nb::handle(static_cast<PyObject *>(userData)).dec_ref();
+      UserData *userData_ = static_cast<UserData *>(userData);
+      userData_->matchAndRewriteCb.dec_ref();
+      userData_->listener.dec_ref();
     };
     callbacks.matchAndRewrite = [](MlirRewritePattern, MlirOperation op,
                                    MlirPatternRewriter rewriter,
                                    void *userData) -> MlirLogicalResult {
-      nb::handle f(static_cast<PyObject *>(userData));
+      UserData *userData_ = static_cast<UserData *>(userData);
+      nb::handle f(userData_->matchAndRewriteCb);
 
       PyMlirContextRef ctx =
           PyMlirContext::forContext(mlirOperationGetContext(op));
       nb::object opView = PyOperation::forOperation(ctx, op)->createOpView();
 
       PyPatternRewriter pyRewriter(rewriter);
-      nb::object listener = nb::cast(pyRewriter.getListener());
-
-      listener.attr("__enter__")();
-      auto exit = llvm::make_scope_exit([listener] {
-        listener.attr("__exit__")(nb::none(), nb::none(), nb::none());
-      });
-      nb::object res = f(opView, pyRewriter);
+      nb::object listener = userData_->listener;
+      nb::object res = f(opView, pyRewriter, listener);
 
       return logicalResultFromObject(res);
     };
+
+    UserData *userData_ = new UserData{matchAndRewriteCb, listener};
     MlirRewritePattern pattern = mlirOpRewritePattenCreate(
-        rootName, benefit, ctx, callbacks, matchAndRewrite.ptr(),
+        rootName, benefit, ctx, callbacks, static_cast<void *>(userData_),
         /* nGeneratedNames */ 0,
         /* generatedNames */ nullptr);
     mlirRewritePatternSetAdd(set, pattern);
@@ -291,13 +298,13 @@ void mlir::python::populateRewriteSubmodule(nb::module_ &m) {
       .def(
           "add",
           [](PyRewritePatternSet &self, nb::handle root, const nb::callable &fn,
-             unsigned benefit) {
+             const nb::object &listener, unsigned benefit) {
             std::string opName =
                 nb::cast<std::string>(root.attr("OPERATION_NAME"));
             self.add(mlirStringRefCreate(opName.data(), opName.size()), benefit,
-                     fn);
+                     fn, listener);
           },
-          "root"_a, "fn"_a, "benefit"_a = 1,
+          "root"_a, "fn"_a, "listener"_a, "benefit"_a = 1,
           "Add a new rewrite pattern on the given root operation with the "
           "callable as the matching and rewriting function and the given "
           "benefit.")

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you for your patch! I understand the code, but since we can already get the listener via
nb::object listener = nb::cast(pyRewriter.getListener())
without passing additional user data, is there any difference?

Copy link
Member Author

@PragmaTwice PragmaTwice Oct 17, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ahhh I think I've got some clue about the inconsistency between our understanding.

I want to explain in a gradual way why we need a global state (or stack). Firstly, we can just throw these stack/state things if we write code like this:

def match_and_rewrite(op, rewriter):
  new_op = arith.add(op.lhs, op.rhs, listener=rewriter.listener) 
  # or just: new_op = arith.add(op.lhs, op.rhs, rewriter=rewriter)
  ...

Here we pass the listener (or the rewriter) into the op constructor so we don't need any global state, and then we can use this listener to call notify.. methods. But in the initial I want to avoid such passing, so some code like this:

def match_and_rewrite(op, rewriter):
  new_op = arith.add(op.lhs, op.rhs) # no listener/rewriter passing, so we need to retrieve it from some global state!
  ...

And then, obviously, we need some global state to know which is the current listener that we need to obtain in the op constructor to call notify methods.

If we want to avoid global state/stack maybe we can just pass the listener to op constructors? this may require some changes to tblgen.

nb::object res = f(opView, pyRewriter);

return logicalResultFromObject(res);
};
MlirRewritePattern pattern = mlirOpRewritePattenCreate(
Expand Down Expand Up @@ -234,6 +247,8 @@ void mlir::python::populateRewriteSubmodule(nb::module_ &m) {
class_<PyPatternRewriter>(m, "PatternRewriter")
.def_prop_ro("ip", &PyPatternRewriter::getInsertionPoint,
"The current insertion point of the PatternRewriter.")
.def_prop_ro("listener", &PyPatternRewriter::getListener,
"The rewrite listener of the PatternRewriter.")
.def(
"replace_op",
[](PyPatternRewriter &self, MlirOperation op,
Expand Down
17 changes: 17 additions & 0 deletions mlir/lib/CAPI/Transforms/Rewrite.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,23 @@ MlirContext mlirRewriterBaseGetContext(MlirRewriterBase rewriter) {
return wrap(unwrap(rewriter)->getContext());
}

MlirRewriterBaseListener
mlirRewriterBaseGetListener(MlirRewriterBase rewriter) {
return wrap(
dyn_cast<RewriterBase::Listener>(unwrap(rewriter)->getListener()));
}

void mlirRewriterBaseListenerNotifyOperationInserted(
MlirRewriterBaseListener listener, MlirOperation op,
MlirOperation insertionPoint) {
OpBuilder::InsertPoint ip;
if (!mlirOperationIsNull(insertionPoint)) {
ip = OpBuilder::InsertPoint(unwrap(insertionPoint)->getBlock(),
Block::iterator(unwrap(insertionPoint)));
}
return unwrap(listener)->notifyOperationInserted(unwrap(op), ip);
}

//===----------------------------------------------------------------------===//
/// Insertion points methods
//===----------------------------------------------------------------------===//
Expand Down