@makslevental makslevental commented Sep 14, 2025

There are cases where the same module can have multiple references (via PyModule::forModule, reached through PyModule::createFromCapsule), so when the PyModules get gc'd, mlirModuleDestroy can be called multiple times for the same underlying mlir::Module (i.e., a double free). So we do actually need a "liveness map" for modules. Note: if type_caster<MlirModule>::from_cpp weren't a thing, we could guarantee this never happened except when users explicitly called PyModule::createFromCapsule.
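
To make the failure mode concrete, here is a minimal pure-Python sketch of the problem and of the interning fix. It does not use the MLIR bindings: `NativeModule`, `native_destroy`, `NaiveWrapper`, and `InternedWrapper` are hypothetical stand-ins, and `release()` models the wrapper's destructor running at garbage collection.

```python
class NativeModule:
    """Hypothetical stand-in for the underlying mlir::Module."""

    def __init__(self, name):
        self.name = name
        self.destroy_calls = 0


def native_destroy(native):
    # Stand-in for mlirModuleDestroy. More than one call for the same
    # native module models a double free.
    native.destroy_calls += 1


class NaiveWrapper:
    """Every wrapper assumes it is the sole owner of the native module."""

    def __init__(self, native):
        self.native = native

    def release(self):
        # Models the wrapper's destructor.
        native_destroy(self.native)


class InternedWrapper:
    """Wrappers are interned in a liveness map keyed by the native module."""

    _live = {}  # id(native) -> wrapper (hypothetical liveness map)

    def __init__(self, native):
        self.native = native

    @classmethod
    def for_module(cls, native):
        existing = cls._live.get(id(native))
        if existing is not None:
            return existing  # reuse the wrapper that already owns it
        wrapper = cls(native)
        cls._live[id(native)] = wrapper
        return wrapper

    def release(self):
        # Erase the map entry, then destroy the native module exactly once.
        del InternedWrapper._live[id(self.native)]
        native_destroy(self.native)


def demo():
    # Without interning: two wrappers, two destructor runs, two destroys.
    n1 = NativeModule("m")
    a, b = NaiveWrapper(n1), NaiveWrapper(n1)
    a.release()
    b.release()

    # With interning: both lookups return one wrapper, so one destructor run.
    n2 = NativeModule("m")
    c = InternedWrapper.for_module(n2)
    d = InternedWrapper.for_module(n2)
    c.release()
    return n1.destroy_calls, c is d, n2.destroy_calls
```

Calling `demo()` returns the destroy count without interning, whether the two interned lookups are the same object, and the destroy count with interning.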

@makslevental makslevental force-pushed the users/makslevental/restore-livemodules branch 4 times, most recently from 4886e7a to 07e4202 Compare September 14, 2025 19:24
@jpienaar jpienaar left a comment

Thanks! We were hitting lifetime issues without.

@makslevental makslevental marked this pull request as ready for review September 14, 2025 19:32
@llvmbot llvmbot added the mlir label Sep 14, 2025
llvmbot commented Sep 14, 2025

@llvm/pr-subscribers-mlir

Author: Maksim Levental (makslevental)

Changes

There are cases where the same module can have multiple references (via PyModule::forModule, reached through PyModule::createFromCapsule), so when the PyModules get gc'd, mlirModuleDestroy can be called multiple times for the same underlying mlir::Module (i.e., a double free). So we do actually need a "liveness map" for modules. Note: if type_caster<MlirModule>::from_cpp weren't a thing, we could guarantee this never happened except when users explicitly called PyModule::createFromCapsule.


Full diff: https://github.com/llvm/llvm-project/pull/158506.diff

3 Files Affected:

  • (modified) mlir/lib/Bindings/Python/IRCore.cpp (+30-12)
  • (modified) mlir/lib/Bindings/Python/IRModule.h (+12)
  • (modified) mlir/test/python/ir/module.py (+5-4)
diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
index 8273a9346e5dd..10360e448858c 100644
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -1079,23 +1079,38 @@ PyLocation &DefaultingPyLocation::resolve() {
 PyModule::PyModule(PyMlirContextRef contextRef, MlirModule module)
     : BaseContextObject(std::move(contextRef)), module(module) {}
 
-PyModule::~PyModule() { mlirModuleDestroy(module); }
+PyModule::~PyModule() {
+  nb::gil_scoped_acquire acquire;
+  auto &liveModules = getContext()->liveModules;
+  assert(liveModules.count(module.ptr) == 1 &&
+         "destroying module not in live map");
+  liveModules.erase(module.ptr);
+  mlirModuleDestroy(module);
+}
 
 PyModuleRef PyModule::forModule(MlirModule module) {
   MlirContext context = mlirModuleGetContext(module);
   PyMlirContextRef contextRef = PyMlirContext::forContext(context);
 
-  // Create.
-  PyModule *unownedModule = new PyModule(std::move(contextRef), module);
-  // Note that the default return value policy on cast is `automatic_reference`,
-  // which means "does not take ownership, does not call delete/dtor".
-  // We use `take_ownership`, which means "Python will call the C++ destructor
-  // and delete operator when the Python wrapper is garbage collected", because
-  // MlirModule actually wraps OwningOpRef<ModuleOp> (see mlirModuleCreateParse
-  // etc).
-  nb::object pyRef = nb::cast(unownedModule, nb::rv_policy::take_ownership);
-  unownedModule->handle = pyRef;
-  return PyModuleRef(unownedModule, std::move(pyRef));
+  nb::gil_scoped_acquire acquire;
+  auto &liveModules = contextRef->liveModules;
+  auto it = liveModules.find(module.ptr);
+  if (it == liveModules.end()) {
+    // Create.
+    PyModule *unownedModule = new PyModule(std::move(contextRef), module);
+    // Note that the default return value policy on cast is automatic_reference,
+    // which does not take ownership (delete will not be called).
+    // Just be explicit.
+    nb::object pyRef = nb::cast(unownedModule, nb::rv_policy::take_ownership);
+    unownedModule->handle = pyRef;
+    liveModules[module.ptr] =
+        std::make_pair(unownedModule->handle, unownedModule);
+    return PyModuleRef(unownedModule, std::move(pyRef));
+  }
+  // Use existing.
+  PyModule *existing = it->second.second;
+  nb::object pyRef = nb::borrow<nb::object>(it->second.first);
+  return PyModuleRef(existing, std::move(pyRef));
 }
 
 nb::object PyModule::createFromCapsule(nb::object capsule) {
@@ -2084,6 +2099,8 @@ PyInsertionPoint PyInsertionPoint::after(PyOperationBase &op) {
   return PyInsertionPoint{block, std::move(nextOpRef)};
 }
 
+size_t PyMlirContext::getLiveModuleCount() { return liveModules.size(); }
+
 nb::object PyInsertionPoint::contextEnter(nb::object insertPoint) {
   return PyThreadContextEntry::pushInsertionPoint(insertPoint);
 }
@@ -2923,6 +2940,7 @@ void mlir::python::populateIRCore(nb::module_ &m) {
              PyMlirContextRef ref = PyMlirContext::forContext(self.get());
              return ref.releaseObject();
            })
+      .def("_get_live_module_count", &PyMlirContext::getLiveModuleCount)
       .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyMlirContext::getCapsule)
       .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyMlirContext::createFromCapsule)
       .def("__enter__", &PyMlirContext::contextEnter)
diff --git a/mlir/lib/Bindings/Python/IRModule.h b/mlir/lib/Bindings/Python/IRModule.h
index 1d1ff29533f98..28b885f136fe0 100644
--- a/mlir/lib/Bindings/Python/IRModule.h
+++ b/mlir/lib/Bindings/Python/IRModule.h
@@ -218,6 +218,10 @@ class PyMlirContext {
   /// Gets the count of live context objects. Used for testing.
   static size_t getLiveCount();
 
+  /// Gets the count of live modules associated with this context.
+  /// Used for testing.
+  size_t getLiveModuleCount();
+
   /// Enter and exit the context manager.
   static nanobind::object contextEnter(nanobind::object context);
   void contextExit(const nanobind::object &excType,
@@ -244,6 +248,14 @@ class PyMlirContext {
   static nanobind::ft_mutex live_contexts_mutex;
   static LiveContextMap &getLiveContexts();
 
+  // Interns all live modules associated with this context. Modules tracked
+  // in this map are valid. When a module is invalidated, it is removed
+  // from this map, and while it still exists as an instance, any
+  // attempt to access it will raise an error.
+  using LiveModuleMap =
+      llvm::DenseMap<const void *, std::pair<nanobind::handle, PyModule *>>;
+  LiveModuleMap liveModules;
+
   bool emitErrorDiagnostics = false;
 
   MlirContext context;
diff --git a/mlir/test/python/ir/module.py b/mlir/test/python/ir/module.py
index ad4c9340a6c82..175f26007ec6a 100644
--- a/mlir/test/python/ir/module.py
+++ b/mlir/test/python/ir/module.py
@@ -121,6 +121,7 @@ def testRoundtripBinary():
 def testModuleOperation():
     ctx = Context()
     module = Module.parse(r"""module @successfulParse {}""", ctx)
+    assert ctx._get_live_module_count() == 1
     op1 = module.operation
     # CHECK: module @successfulParse
     print(op1)
@@ -145,6 +146,7 @@ def testModuleOperation():
     op1 = None
     op2 = None
     gc.collect()
+    assert ctx._get_live_module_count() == 0
 
 
 # CHECK-LABEL: TEST: testModuleCapsule
@@ -152,17 +154,16 @@ def testModuleOperation():
 def testModuleCapsule():
     ctx = Context()
     module = Module.parse(r"""module @successfulParse {}""", ctx)
+    assert ctx._get_live_module_count() == 1
     # CHECK: "mlir.ir.Module._CAPIPtr"
     module_capsule = module._CAPIPtr
     print(module_capsule)
     module_dup = Module._CAPICreate(module_capsule)
-    assert module is not module_dup
-    assert module == module_dup
-    module._clear_mlir_module()
-    assert module != module_dup
+    assert module is module_dup
     assert module_dup.context is ctx
     # Gc and verify destructed.
     module = None
     module_capsule = None
     module_dup = None
     gc.collect()
+    assert ctx._get_live_module_count() == 0
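
The updated tests check two things: a module recreated from its capsule is the same Python object, and the per-context live count drops back to zero once all references are gone. The sketch below models that lifecycle in pure Python, under the assumption that the map holds non-owning references (the diff stores a `nanobind::handle`, which a `weakref.ref` approximates here). `Context`, `Module`, `for_key`, and `live_module_count` are hypothetical names, not the bindings' API.

```python
import gc
import weakref


class Context:
    """Hypothetical stand-in for a context that tracks its live modules."""

    def __init__(self):
        # key -> weak reference to the wrapper. Entries are non-owning,
        # so the map never keeps a wrapper alive by itself.
        self._live_modules = {}

    def live_module_count(self):
        return len(self._live_modules)


class Module:
    """Hypothetical wrapper that is interned per context."""

    def __init__(self, ctx, key):
        self.ctx = ctx
        self.key = key

    @classmethod
    def for_key(cls, ctx, key):
        ref = ctx._live_modules.get(key)
        existing = ref() if ref is not None else None
        if existing is not None:
            return existing  # use the existing wrapper
        module = cls(ctx, key)
        ctx._live_modules[key] = weakref.ref(module)
        return module

    def __del__(self):
        # Models the destructor erasing the entry from the liveness map.
        self.ctx._live_modules.pop(self.key, None)


def demo():
    ctx = Context()
    module = Module.for_key(ctx, "m")
    dup = Module.for_key(ctx, "m")
    same = module is dup
    before = ctx.live_module_count()
    # Drop every reference, then collect, as the test does.
    module = None
    dup = None
    gc.collect()
    after = ctx.live_module_count()
    return same, before, after
```

The design point this illustrates is that the map is only an index: ownership stays with the Python wrapper, and the entry disappears when the wrapper does.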

@makslevental makslevental changed the title [MLIR][Python] restore liveModuleMap [MLIR][Python] restore liveModuleMap Sep 14, 2025
@makslevental makslevental force-pushed the users/makslevental/restore-livemodules branch from 07e4202 to 9698a5a Compare September 14, 2025 23:41
@jpienaar jpienaar merged commit 6a4f664 into llvm:main Sep 15, 2025
11 of 12 checks passed
@makslevental makslevental deleted the users/makslevental/restore-livemodules branch September 15, 2025 05:07
@jyknight

Noting that this was a fix for an issue introduced by #155114.
