1414#include " mlir/Bindings/Python/Nanobind.h"
1515#include " mlir-c/Bindings/Python/Interop.h" // This is expected after nanobind.
1616// clang-format on
17- #include " nanobind/trampoline.h"
1817
1918namespace nb = nanobind;
2019using namespace nb ::literals;
@@ -23,81 +22,6 @@ using namespace mlir::python;
2322
2423namespace {
2524
26- // A base class for defining passes in Python
27- // Users are expected to subclass this and implement the `run` method, e.g.
28- // ```
29- // class MyPass(Pass):
30- // def __init__(self):
31- // super().__init__("MyPass", ..)
32- // # other init stuff..
33- // def run(self, operation):
34- // # do something with operation..
35- // pass
36- // ```
37- class PyPassBase {
38- public:
39- PyPassBase (std::string name, std::string argument, std::string description,
40- std::string opName)
41- : name(std::move(name)), argument(std::move(argument)),
42- description (std::move(description)), opName(std::move(opName)) {
43- callbacks.construct = [](void *obj) {};
44- callbacks.destruct = [](void *obj) {
45- nb::handle (static_cast <PyObject *>(obj)).dec_ref ();
46- };
47- callbacks.run = [](MlirOperation op, MlirExternalPass, void *obj) {
48- auto handle = nb::handle (static_cast <PyObject *>(obj));
49- nb::cast<PyPassBase *>(handle)->run (op);
50- };
51- callbacks.clone = [](void *obj) -> void * {
52- nb::object copy = nb::module_::import_ (" copy" );
53- nb::object deepcopy = copy.attr (" deepcopy" );
54- return deepcopy (obj).release ().ptr ();
55- };
56- callbacks.initialize = nullptr ;
57- }
58-
59- // this method should be overridden by subclasses in Python.
60- virtual void run (MlirOperation op) = 0;
61-
62- virtual ~PyPassBase () = default ;
63-
64- // Make an MlirPass instance on-the-fly that wraps this object.
65- // Note that passmanager will take the ownership of the returned
66- // object and release it when appropriate.
67- MlirPass make () {
68- auto *obj = nb::find (this ).release ().ptr ();
69- return mlirCreateExternalPass (
70- mlirTypeIDCreate (this ), mlirStringRefCreate (name.data (), name.length ()),
71- mlirStringRefCreate (argument.data (), argument.length ()),
72- mlirStringRefCreate (description.data (), description.length ()),
73- mlirStringRefCreate (opName.data (), opName.size ()), 0 , nullptr ,
74- callbacks, obj);
75- }
76-
77- const std::string &getName () const { return name; }
78- const std::string &getArgument () const { return argument; }
79- const std::string &getDescription () const { return description; }
80- const std::string &getOpName () const { return opName; }
81-
82- private:
83- MlirExternalPassCallbacks callbacks;
84-
85- std::string name;
86- std::string argument;
87- std::string description;
88- std::string opName;
89- };
90-
91- // A trampoline class upon PyPassBase.
92- // Refer to
93- // https://nanobind.readthedocs.io/en/latest/classes.html#overriding-virtual-functions-in-python
94- class PyPass : PyPassBase {
95- public:
96- NB_TRAMPOLINE (PyPassBase, 1 );
97-
98- void run (MlirOperation op) override { NB_OVERRIDE_PURE (run, op); }
99- };
100-
10125// / Owning Wrapper around a PassManager.
10226class PyPassManager {
10327public:
@@ -130,26 +54,6 @@ class PyPassManager {
13054
13155} // namespace
13256
133- void mlir::python::populatePassSubmodule (nanobind::module_ &m) {
134- // ----------------------------------------------------------------------------
135- // Mapping of the Python-defined Pass interface
136- // ----------------------------------------------------------------------------
137- nb::class_<PyPassBase, PyPass>(m, " Pass" )
138- .def (nb::init<std::string, std::string, std::string, std::string>(),
139- " name" _a, nb::kw_only (), " argument" _a = " " , " description" _a = " " ,
140- " op_name" _a = " " , " Create a new Pass." )
141- .def (" run" , &PyPassBase::run, " operation" _a,
142- " Run the pass on the provided operation." )
143- .def_prop_ro (" name" ,
144- [](const PyPassBase &self) { return self.getName (); })
145- .def_prop_ro (" argument" ,
146- [](const PyPassBase &self) { return self.getArgument (); })
147- .def_prop_ro (" description" ,
148- [](const PyPassBase &self) { return self.getDescription (); })
149- .def_prop_ro (" op_name" ,
150- [](const PyPassBase &self) { return self.getOpName (); });
151- }
152-
15357// / Create the `mlir.passmanager` here.
15458void mlir::python::populatePassManagerSubmodule (nb::module_ &m) {
15559 // ----------------------------------------------------------------------------
@@ -256,11 +160,44 @@ void mlir::python::populatePassManagerSubmodule(nb::module_ &m) {
256160 " Add textual pipeline elements to the pass manager. Throws a "
257161 " ValueError if the pipeline can't be parsed." )
258162 .def (
259- " add" ,
260- [](PyPassManager &passManager, PyPassBase &pass) {
261- mlirPassManagerAddOwnedPass (passManager.get (), pass.make ());
163+ " add_python_pass" ,
164+ [](PyPassManager &passManager, const nb::callable &run,
165+ std::optional<std::string> &name, const std::string &argument,
166+ const std::string &description, const std::string &opName) {
167+ if (!name.has_value ()) {
168+ name = nb::cast<std::string>(
169+ nb::borrow<nb::str>(run.attr (" __name__" )));
170+ }
171+ MlirTypeIDAllocator typeIDAllocator = mlirTypeIDAllocatorCreate ();
172+ MlirTypeID passID =
173+ mlirTypeIDAllocatorAllocateTypeID (typeIDAllocator);
174+ MlirExternalPassCallbacks callbacks;
175+ callbacks.construct = [](void *obj) {
176+ (void )nb::handle (static_cast <PyObject *>(obj)).inc_ref ();
177+ };
178+ callbacks.destruct = [](void *obj) {
179+ (void )nb::handle (static_cast <PyObject *>(obj)).dec_ref ();
180+ };
181+ callbacks.initialize = nullptr ;
182+ callbacks.clone = [](void *) -> void * {
183+ throw std::runtime_error (" Cloning Python passes not supported" );
184+ };
185+ callbacks.run = [](MlirOperation op, MlirExternalPass,
186+ void *userData) {
187+ nb::borrow<nb::callable>(static_cast <PyObject *>(userData))(op);
188+ };
189+ auto externalPass = mlirCreateExternalPass (
190+ passID, mlirStringRefCreate (name->data (), name->length ()),
191+ mlirStringRefCreate (argument.data (), argument.length ()),
192+ mlirStringRefCreate (description.data (), description.length ()),
193+ mlirStringRefCreate (opName.data (), opName.size ()),
194+ /* nDependentDialects*/ 0 , /* dependentDialects*/ nullptr ,
195+ callbacks, /* userData*/ run.ptr ());
196+ mlirPassManagerAddOwnedPass (passManager.get (), externalPass);
262197 },
263- " pass" _a, " Add a python-defined pass to the pass manager." )
198+ " run" _a, " name" _a.none () = nb::none (), " argument" _a.none () = " " ,
199+ " description" _a.none () = " " , " op_name" _a.none () = " " ,
200+ " Add a python-defined pass to the pass manager." )
264201 .def (
265202 " run" ,
266203 [](PyPassManager &passManager, PyOperationBase &op,
0 commit comments