-
Notifications
You must be signed in to change notification settings - Fork 15.2k
[MLIR][Python] restore liveModuleMap
#158506
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[MLIR][Python] restore liveModuleMap
#158506
Conversation
4886e7a
to
07e4202
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks! We were hitting lifetime issues without.
@llvm/pr-subscribers-mlir Author: Maksim Levental (makslevental) ChangesThere are cases where the same module can have multiple references (via Full diff: https://github.com/llvm/llvm-project/pull/158506.diff 3 Files Affected:
diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
index 8273a9346e5dd..10360e448858c 100644
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -1079,23 +1079,38 @@ PyLocation &DefaultingPyLocation::resolve() {
PyModule::PyModule(PyMlirContextRef contextRef, MlirModule module)
: BaseContextObject(std::move(contextRef)), module(module) {}
-PyModule::~PyModule() { mlirModuleDestroy(module); }
+PyModule::~PyModule() {
+ nb::gil_scoped_acquire acquire;
+ auto &liveModules = getContext()->liveModules;
+ assert(liveModules.count(module.ptr) == 1 &&
+ "destroying module not in live map");
+ liveModules.erase(module.ptr);
+ mlirModuleDestroy(module);
+}
PyModuleRef PyModule::forModule(MlirModule module) {
MlirContext context = mlirModuleGetContext(module);
PyMlirContextRef contextRef = PyMlirContext::forContext(context);
- // Create.
- PyModule *unownedModule = new PyModule(std::move(contextRef), module);
- // Note that the default return value policy on cast is `automatic_reference`,
- // which means "does not take ownership, does not call delete/dtor".
- // We use `take_ownership`, which means "Python will call the C++ destructor
- // and delete operator when the Python wrapper is garbage collected", because
- // MlirModule actually wraps OwningOpRef<ModuleOp> (see mlirModuleCreateParse
- // etc).
- nb::object pyRef = nb::cast(unownedModule, nb::rv_policy::take_ownership);
- unownedModule->handle = pyRef;
- return PyModuleRef(unownedModule, std::move(pyRef));
+ nb::gil_scoped_acquire acquire;
+ auto &liveModules = contextRef->liveModules;
+ auto it = liveModules.find(module.ptr);
+ if (it == liveModules.end()) {
+ // Create.
+ PyModule *unownedModule = new PyModule(std::move(contextRef), module);
+ // Note that the default return value policy on cast is automatic_reference,
+ // which does not take ownership (delete will not be called).
+ // Just be explicit.
+ nb::object pyRef = nb::cast(unownedModule, nb::rv_policy::take_ownership);
+ unownedModule->handle = pyRef;
+ liveModules[module.ptr] =
+ std::make_pair(unownedModule->handle, unownedModule);
+ return PyModuleRef(unownedModule, std::move(pyRef));
+ }
+ // Use existing.
+ PyModule *existing = it->second.second;
+ nb::object pyRef = nb::borrow<nb::object>(it->second.first);
+ return PyModuleRef(existing, std::move(pyRef));
}
nb::object PyModule::createFromCapsule(nb::object capsule) {
@@ -2084,6 +2099,8 @@ PyInsertionPoint PyInsertionPoint::after(PyOperationBase &op) {
return PyInsertionPoint{block, std::move(nextOpRef)};
}
+size_t PyMlirContext::getLiveModuleCount() { return liveModules.size(); }
+
nb::object PyInsertionPoint::contextEnter(nb::object insertPoint) {
return PyThreadContextEntry::pushInsertionPoint(insertPoint);
}
@@ -2923,6 +2940,7 @@ void mlir::python::populateIRCore(nb::module_ &m) {
PyMlirContextRef ref = PyMlirContext::forContext(self.get());
return ref.releaseObject();
})
+ .def("_get_live_module_count", &PyMlirContext::getLiveModuleCount)
.def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyMlirContext::getCapsule)
.def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyMlirContext::createFromCapsule)
.def("__enter__", &PyMlirContext::contextEnter)
diff --git a/mlir/lib/Bindings/Python/IRModule.h b/mlir/lib/Bindings/Python/IRModule.h
index 1d1ff29533f98..28b885f136fe0 100644
--- a/mlir/lib/Bindings/Python/IRModule.h
+++ b/mlir/lib/Bindings/Python/IRModule.h
@@ -218,6 +218,10 @@ class PyMlirContext {
/// Gets the count of live context objects. Used for testing.
static size_t getLiveCount();
+ /// Gets the count of live modules associated with this context.
+ /// Used for testing.
+ size_t getLiveModuleCount();
+
/// Enter and exit the context manager.
static nanobind::object contextEnter(nanobind::object context);
void contextExit(const nanobind::object &excType,
@@ -244,6 +248,14 @@ class PyMlirContext {
static nanobind::ft_mutex live_contexts_mutex;
static LiveContextMap &getLiveContexts();
+ // Interns all live modules associated with this context. Modules tracked
+ // in this map are valid. When a module is invalidated, it is removed
+ // from this map, and while it still exists as an instance, any
+ // attempt to access it will raise an error.
+ using LiveModuleMap =
+ llvm::DenseMap<const void *, std::pair<nanobind::handle, PyModule *>>;
+ LiveModuleMap liveModules;
+
bool emitErrorDiagnostics = false;
MlirContext context;
diff --git a/mlir/test/python/ir/module.py b/mlir/test/python/ir/module.py
index ad4c9340a6c82..175f26007ec6a 100644
--- a/mlir/test/python/ir/module.py
+++ b/mlir/test/python/ir/module.py
@@ -121,6 +121,7 @@ def testRoundtripBinary():
def testModuleOperation():
ctx = Context()
module = Module.parse(r"""module @successfulParse {}""", ctx)
+ assert ctx._get_live_module_count() == 1
op1 = module.operation
# CHECK: module @successfulParse
print(op1)
@@ -145,6 +146,7 @@ def testModuleOperation():
op1 = None
op2 = None
gc.collect()
+ assert ctx._get_live_module_count() == 0
# CHECK-LABEL: TEST: testModuleCapsule
@@ -152,17 +154,16 @@ def testModuleOperation():
def testModuleCapsule():
ctx = Context()
module = Module.parse(r"""module @successfulParse {}""", ctx)
+ assert ctx._get_live_module_count() == 1
# CHECK: "mlir.ir.Module._CAPIPtr"
module_capsule = module._CAPIPtr
print(module_capsule)
module_dup = Module._CAPICreate(module_capsule)
- assert module is not module_dup
- assert module == module_dup
- module._clear_mlir_module()
- assert module != module_dup
+ assert module is module_dup
assert module_dup.context is ctx
# Gc and verify destructed.
module = None
module_capsule = None
module_dup = None
gc.collect()
+ assert ctx._get_live_module_count() == 0
|
liveModuleMap
07e4202
to
9698a5a
Compare
Noting that this was a fix for an issue introduced by #155114. |
There are cases where the same module can have multiple references (via
PyModule::forModule
viaPyModule::createFromCapsule
) and thus whenPyModule
s get gc'dmlirModuleDestroy
can get called multiple times for the same actual underlyingmlir::Module
(i.e., double free). So we do actually need a "liveness map" for modules. Note, iftype_caster<MlirModule>::from_cpp
weren't a thing we could guarantree this never happened except explicitly when users calledPyModule::createFromCapsule
.