Skip to content

Commit 8386c87

Browse files
committed
[MLIR][Python] Support Python-defined passes in MLIR
1 parent 0a99348 commit 8386c87

File tree

3 files changed

+81
-1
lines changed

3 files changed

+81
-1
lines changed

mlir/lib/Bindings/Python/MainModule.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -139,4 +139,5 @@ NB_MODULE(_mlir, m) {
139139
auto passModule =
140140
m.def_submodule("passmanager", "MLIR Pass Management Bindings");
141141
populatePassManagerSubmodule(passModule);
142+
populatePassSubmodule(passModule);
142143
}

mlir/lib/Bindings/Python/Pass.cpp

Lines changed: 79 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,9 +9,11 @@
99
#include "Pass.h"
1010

1111
#include "IRModule.h"
12+
#include "mlir-c/Bindings/Python/Interop.h" // This is expected after nanobind.
1213
#include "mlir-c/Pass.h"
1314
#include "mlir/Bindings/Python/Nanobind.h"
14-
#include "mlir-c/Bindings/Python/Interop.h" // This is expected after nanobind.
15+
#include "nanobind/trampoline.h"
16+
#include "llvm/Support/ErrorHandling.h"
1517

1618
namespace nb = nanobind;
1719
using namespace nb::literals;
@@ -20,6 +22,63 @@ using namespace mlir::python;
2022

2123
namespace {
2224

25+
// A base class for defining passes in Python
26+
// Users are expected to subclass this and implement the `run` method, e.g.
27+
// ```
28+
// class MyPass(mlir.passmanager.Pass):
29+
// def run(self, operation):
30+
// # do something with operation
31+
// pass
32+
// ```
33+
class PyPassBase {
34+
public:
35+
PyPassBase() : callbacks{} {
36+
callbacks.construct = [](void *) {};
37+
callbacks.destruct = [](void *) {};
38+
callbacks.run = [](MlirOperation op, MlirExternalPass, void *obj) {
39+
static_cast<PyPassBase *>(obj)->run(op);
40+
};
41+
// TODO: currently we don't support pass cloning in python
42+
// due to lifetime management issues.
43+
callbacks.clone = [](void *obj) -> void * {
44+
// since the caller here should be MLIR C++ code,
45+
// we need to avoid using exceptions like throw py::value_error(...).
46+
llvm_unreachable("cloning of python-defined passes is not supported");
47+
};
48+
}
49+
50+
// this method should be overridden by subclasses in Python.
51+
virtual void run(MlirOperation op) = 0;
52+
53+
virtual ~PyPassBase() = default;
54+
55+
// Make an MlirPass instance on-the-fly that wraps this object.
56+
// Note that passmanager will take the ownership of the returned
57+
// object and release it when appropriate.
58+
// Also, `*this` must remain alive as long as the returned object is alive.
59+
MlirPass make() {
60+
return mlirCreateExternalPass(
61+
mlirTypeIDCreate(this),
62+
mlirStringRefCreateFromCString("python-example-pass"),
63+
mlirStringRefCreateFromCString(""),
64+
mlirStringRefCreateFromCString("Python Example Pass"),
65+
mlirStringRefCreateFromCString(""), 0, nullptr, callbacks, this);
66+
}
67+
68+
private:
69+
MlirExternalPassCallbacks callbacks;
70+
};
71+
72+
// A trampoline class upon PyPassBase.
73+
// Refer to
74+
// https://nanobind.readthedocs.io/en/latest/classes.html#overriding-virtual-functions-in-python
75+
class PyPass : PyPassBase {
76+
public:
77+
NB_TRAMPOLINE(PyPassBase, 1);
78+
79+
void run(MlirOperation op) override { NB_OVERRIDE_PURE(run, op); }
80+
};
81+
2382
/// Owning Wrapper around a PassManager.
2483
class PyPassManager {
2584
public:
@@ -52,6 +111,16 @@ class PyPassManager {
52111

53112
} // namespace
54113

114+
void mlir::python::populatePassSubmodule(nanobind::module_ &m) {
115+
//----------------------------------------------------------------------------
116+
// Mapping of the Python-defined Pass interface
117+
//----------------------------------------------------------------------------
118+
nb::class_<PyPassBase, PyPass>(m, "Pass")
119+
.def(nb::init<>(), "Create a new Pass.")
120+
.def("run", &PyPassBase::run, "operation"_a,
121+
"Run the pass on the provided operation.");
122+
}
123+
55124
/// Create the `mlir.passmanager` here.
56125
void mlir::python::populatePassManagerSubmodule(nb::module_ &m) {
57126
//----------------------------------------------------------------------------
@@ -157,6 +226,15 @@ void mlir::python::populatePassManagerSubmodule(nb::module_ &m) {
157226
"pipeline"_a,
158227
"Add textual pipeline elements to the pass manager. Throws a "
159228
"ValueError if the pipeline can't be parsed.")
229+
.def(
230+
"add",
231+
[](PyPassManager &passManager, PyPassBase &pass) {
232+
mlirPassManagerAddOwnedPass(passManager.get(), pass.make());
233+
},
234+
"pass"_a, "Add a python-defined pass to the pass manager.",
235+
// NOTE that we should keep the pass object alive as long as the
236+
// passManager to prevent dangling objects.
237+
nb::keep_alive<1, 2>())
160238
.def(
161239
"run",
162240
[](PyPassManager &passManager, PyOperationBase &op,

mlir/lib/Bindings/Python/Pass.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ namespace mlir {
1515
namespace python {
1616

1717
void populatePassManagerSubmodule(nanobind::module_ &m);
18+
void populatePassSubmodule(nanobind::module_ &m);
1819

1920
} // namespace python
2021
} // namespace mlir

0 commit comments

Comments
 (0)