99#include " Pass.h"
1010
1111#include " IRModule.h"
12+ #include " mlir-c/Bindings/Python/Interop.h" // This is expected after nanobind.
1213#include " mlir-c/Pass.h"
1314#include " mlir/Bindings/Python/Nanobind.h"
14- #include " mlir-c/Bindings/Python/Interop.h" // This is expected after nanobind.
15+ #include " nanobind/trampoline.h"
16+ #include " llvm/Support/ErrorHandling.h"
1517
1618namespace nb = nanobind;
1719using namespace nb ::literals;
@@ -20,6 +22,63 @@ using namespace mlir::python;
2022
2123namespace {
2224
25+ // A base class for defining passes in Python
26+ // Users are expected to subclass this and implement the `run` method, e.g.
27+ // ```
28+ // class MyPass(mlir.passmanager.Pass):
29+ // def run(self, operation):
30+ // # do something with operation
31+ // pass
32+ // ```
33+ class PyPassBase {
34+ public:
35+ PyPassBase () : callbacks{} {
36+ callbacks.construct = [](void *) {};
37+ callbacks.destruct = [](void *) {};
38+ callbacks.run = [](MlirOperation op, MlirExternalPass, void *obj) {
39+ static_cast <PyPassBase *>(obj)->run (op);
40+ };
41+ // TODO: currently we don't support pass cloning in python
42+ // due to lifetime management issues.
43+ callbacks.clone = [](void *obj) -> void * {
44+ // since the caller here should be MLIR C++ code,
45+ // we need to avoid using exceptions like throw py::value_error(...).
46+ llvm_unreachable (" cloning of python-defined passes is not supported" );
47+ };
48+ }
49+
50+ // this method should be overridden by subclasses in Python.
51+ virtual void run (MlirOperation op) = 0;
52+
53+ virtual ~PyPassBase () = default ;
54+
55+ // Make an MlirPass instance on-the-fly that wraps this object.
56+ // Note that passmanager will take the ownership of the returned
57+ // object and release it when appropriate.
58+ // Also, `*this` must remain alive as long as the returned object is alive.
59+ MlirPass make () {
60+ return mlirCreateExternalPass (
61+ mlirTypeIDCreate (this ),
62+ mlirStringRefCreateFromCString (" python-example-pass" ),
63+ mlirStringRefCreateFromCString (" " ),
64+ mlirStringRefCreateFromCString (" Python Example Pass" ),
65+ mlirStringRefCreateFromCString (" " ), 0 , nullptr , callbacks, this );
66+ }
67+
68+ private:
69+ MlirExternalPassCallbacks callbacks;
70+ };
71+
72+ // A trampoline class upon PyPassBase.
73+ // Refer to
74+ // https://nanobind.readthedocs.io/en/latest/classes.html#overriding-virtual-functions-in-python
75+ class PyPass : PyPassBase {
76+ public:
77+ NB_TRAMPOLINE (PyPassBase, 1 );
78+
79+ void run (MlirOperation op) override { NB_OVERRIDE_PURE (run, op); }
80+ };
81+
2382// / Owning Wrapper around a PassManager.
2483class PyPassManager {
2584public:
@@ -52,6 +111,16 @@ class PyPassManager {
52111
53112} // namespace
54113
114+ void mlir::python::populatePassSubmodule (nanobind::module_ &m) {
115+ // ----------------------------------------------------------------------------
116+ // Mapping of the Python-defined Pass interface
117+ // ----------------------------------------------------------------------------
118+ nb::class_<PyPassBase, PyPass>(m, " Pass" )
119+ .def (nb::init<>(), " Create a new Pass." )
120+ .def (" run" , &PyPassBase::run, " operation" _a,
121+ " Run the pass on the provided operation." );
122+ }
123+
55124// / Create the `mlir.passmanager` here.
56125void mlir::python::populatePassManagerSubmodule (nb::module_ &m) {
57126 // ----------------------------------------------------------------------------
@@ -157,6 +226,15 @@ void mlir::python::populatePassManagerSubmodule(nb::module_ &m) {
157226 " pipeline" _a,
158227 " Add textual pipeline elements to the pass manager. Throws a "
159228 " ValueError if the pipeline can't be parsed." )
229+ .def (
230+ " add" ,
231+ [](PyPassManager &passManager, PyPassBase &pass) {
232+ mlirPassManagerAddOwnedPass (passManager.get (), pass.make ());
233+ },
234+ " pass" _a, " Add a python-defined pass to the pass manager." ,
235+ // NOTE that we should keep the pass object alive as long as the
236+ // passManager to prevent dangling objects.
237+ nb::keep_alive<1 , 2 >())
160238 .def (
161239 " run" ,
162240 [](PyPassManager &passManager, PyOperationBase &op,
0 commit comments