@@ -35,20 +35,22 @@ class PyPassBase {
3535public:
3636 PyPassBase (std::string name, std::string argument, std::string description,
3737 std::string opName)
38- : callbacks{}, name(std::move(name)), argument(std::move(argument)),
38+ : name(std::move(name)), argument(std::move(argument)),
3939 description (std::move(description)), opName(std::move(opName)) {
40- callbacks.construct = [](void *) {};
41- callbacks.destruct = [](void *) {};
40+ callbacks.construct = [](void *obj) {};
41+ callbacks.destruct = [](void *obj) {
42+ nb::handle (static_cast <PyObject *>(obj)).dec_ref ();
43+ };
4244 callbacks.run = [](MlirOperation op, MlirExternalPass, void *obj) {
43- static_cast <PyPassBase *>(obj)->run (op);
45+ auto handle = nb::handle (static_cast <PyObject *>(obj));
46+ nb::cast<PyPassBase *>(handle)->run (op);
4447 };
45- // TODO: currently we don't support pass cloning in python
46- // due to lifetime management issues.
4748 callbacks.clone = [](void *obj) -> void * {
48- // since the caller here should be MLIR C++ code,
49- // we need to avoid using exceptions like throw py::value_error(...).
50- llvm_unreachable ( " cloning of python-defined passes is not supported " );
49+ nb::object copy = nb::module_::import_ ( " copy " );
50+ nb::object deepcopy = copy. attr ( " deepcopy " );
51+ return deepcopy (obj). release (). ptr ( );
5152 };
53+ callbacks.initialize = nullptr ;
5254 }
5355
5456 // this method should be overridden by subclasses in Python.
@@ -61,12 +63,13 @@ class PyPassBase {
6163 // object and release it when appropriate.
6264 // Also, `*this` must remain alive as long as the returned object is alive.
6365 MlirPass make () {
66+ auto *obj = nb::find (this ).release ().ptr ();
6467 return mlirCreateExternalPass (
6568 mlirTypeIDCreate (this ), mlirStringRefCreate (name.data (), name.length ()),
6669 mlirStringRefCreate (argument.data (), argument.length ()),
6770 mlirStringRefCreate (description.data (), description.length ()),
6871 mlirStringRefCreate (opName.data (), opName.size ()), 0 , nullptr ,
69- callbacks, this );
72+ callbacks, obj );
7073 }
7174
7275 const std::string &getName () const { return name; }
@@ -255,10 +258,7 @@ void mlir::python::populatePassManagerSubmodule(nb::module_ &m) {
255258 [](PyPassManager &passManager, PyPassBase &pass) {
256259 mlirPassManagerAddOwnedPass (passManager.get (), pass.make ());
257260 },
258- " pass" _a, " Add a python-defined pass to the pass manager." ,
259- // NOTE that we should keep the pass object alive as long as the
260- // passManager to prevent dangling objects.
261- nb::keep_alive<1 , 2 >())
261+ " pass" _a, " Add a python-defined pass to the pass manager." )
262262 .def (
263263 " run" ,
264264 [](PyPassManager &passManager, PyOperationBase &op,
0 commit comments