Skip to content

Commit 7556ca2

Browse files
committed
fix lifetime issue
1 parent cb82621 commit 7556ca2

File tree

1 file changed

+14
-14
lines changed

1 file changed

+14
-14
lines changed

mlir/lib/Bindings/Python/Pass.cpp

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -35,20 +35,22 @@ class PyPassBase {
3535
public:
3636
PyPassBase(std::string name, std::string argument, std::string description,
3737
std::string opName)
38-
: callbacks{}, name(std::move(name)), argument(std::move(argument)),
38+
: name(std::move(name)), argument(std::move(argument)),
3939
description(std::move(description)), opName(std::move(opName)) {
40-
callbacks.construct = [](void *) {};
41-
callbacks.destruct = [](void *) {};
40+
callbacks.construct = [](void *obj) {};
41+
callbacks.destruct = [](void *obj) {
42+
nb::handle(static_cast<PyObject *>(obj)).dec_ref();
43+
};
4244
callbacks.run = [](MlirOperation op, MlirExternalPass, void *obj) {
43-
static_cast<PyPassBase *>(obj)->run(op);
45+
auto handle = nb::handle(static_cast<PyObject *>(obj));
46+
nb::cast<PyPassBase *>(handle)->run(op);
4447
};
45-
// TODO: currently we don't support pass cloning in python
46-
// due to lifetime management issues.
4748
callbacks.clone = [](void *obj) -> void * {
48-
// since the caller here should be MLIR C++ code,
49-
// we need to avoid using exceptions like throw py::value_error(...).
50-
llvm_unreachable("cloning of python-defined passes is not supported");
49+
nb::object copy = nb::module_::import_("copy");
50+
nb::object deepcopy = copy.attr("deepcopy");
51+
return deepcopy(obj).release().ptr();
5152
};
53+
callbacks.initialize = nullptr;
5254
}
5355

5456
// this method should be overridden by subclasses in Python.
@@ -61,12 +63,13 @@ class PyPassBase {
6163
// object and release it when appropriate.
6264
// Also, `*this` must remain alive as long as the returned object is alive.
6365
MlirPass make() {
66+
auto *obj = nb::find(this).release().ptr();
6467
return mlirCreateExternalPass(
6568
mlirTypeIDCreate(this), mlirStringRefCreate(name.data(), name.length()),
6669
mlirStringRefCreate(argument.data(), argument.length()),
6770
mlirStringRefCreate(description.data(), description.length()),
6871
mlirStringRefCreate(opName.data(), opName.size()), 0, nullptr,
69-
callbacks, this);
72+
callbacks, obj);
7073
}
7174

7275
const std::string &getName() const { return name; }
@@ -255,10 +258,7 @@ void mlir::python::populatePassManagerSubmodule(nb::module_ &m) {
255258
[](PyPassManager &passManager, PyPassBase &pass) {
256259
mlirPassManagerAddOwnedPass(passManager.get(), pass.make());
257260
},
258-
"pass"_a, "Add a python-defined pass to the pass manager.",
259-
// NOTE that we should keep the pass object alive as long as the
260-
// passManager to prevent dangling objects.
261-
nb::keep_alive<1, 2>())
261+
"pass"_a, "Add a python-defined pass to the pass manager.")
262262
.def(
263263
"run",
264264
[](PyPassManager &passManager, PyOperationBase &op,

0 commit comments

Comments
 (0)