- 
                Notifications
    You must be signed in to change notification settings 
- Fork 15k
          [MLIR][Python] restore liveModuleMap
          #158506
        
          New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
  
    [MLIR][Python] restore liveModuleMap
  
  #158506
              Conversation
4886e7a    to
    07e4202      
    Compare
  
    There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks! We were hitting lifetime issues without it.
| @llvm/pr-subscribers-mlir Author: Maksim Levental (makslevental) Changes: There are cases where the same module can have multiple references (via `PyModule::forModule` via `PyModule::createFromCapsule`) … Full diff: https://github.com/llvm/llvm-project/pull/158506.diff 3 Files Affected: 
 diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
index 8273a9346e5dd..10360e448858c 100644
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -1079,23 +1079,38 @@ PyLocation &DefaultingPyLocation::resolve() {
 PyModule::PyModule(PyMlirContextRef contextRef, MlirModule module)
     : BaseContextObject(std::move(contextRef)), module(module) {}
 
-PyModule::~PyModule() { mlirModuleDestroy(module); }
+PyModule::~PyModule() {
+  nb::gil_scoped_acquire acquire;
+  auto &liveModules = getContext()->liveModules;
+  assert(liveModules.count(module.ptr) == 1 &&
+         "destroying module not in live map");
+  liveModules.erase(module.ptr);
+  mlirModuleDestroy(module);
+}
 
 PyModuleRef PyModule::forModule(MlirModule module) {
   MlirContext context = mlirModuleGetContext(module);
   PyMlirContextRef contextRef = PyMlirContext::forContext(context);
 
-  // Create.
-  PyModule *unownedModule = new PyModule(std::move(contextRef), module);
-  // Note that the default return value policy on cast is `automatic_reference`,
-  // which means "does not take ownership, does not call delete/dtor".
-  // We use `take_ownership`, which means "Python will call the C++ destructor
-  // and delete operator when the Python wrapper is garbage collected", because
-  // MlirModule actually wraps OwningOpRef<ModuleOp> (see mlirModuleCreateParse
-  // etc).
-  nb::object pyRef = nb::cast(unownedModule, nb::rv_policy::take_ownership);
-  unownedModule->handle = pyRef;
-  return PyModuleRef(unownedModule, std::move(pyRef));
+  nb::gil_scoped_acquire acquire;
+  auto &liveModules = contextRef->liveModules;
+  auto it = liveModules.find(module.ptr);
+  if (it == liveModules.end()) {
+    // Create.
+    PyModule *unownedModule = new PyModule(std::move(contextRef), module);
+    // Note that the default return value policy on cast is automatic_reference,
+    // which does not take ownership (delete will not be called).
+    // Just be explicit.
+    nb::object pyRef = nb::cast(unownedModule, nb::rv_policy::take_ownership);
+    unownedModule->handle = pyRef;
+    liveModules[module.ptr] =
+        std::make_pair(unownedModule->handle, unownedModule);
+    return PyModuleRef(unownedModule, std::move(pyRef));
+  }
+  // Use existing.
+  PyModule *existing = it->second.second;
+  nb::object pyRef = nb::borrow<nb::object>(it->second.first);
+  return PyModuleRef(existing, std::move(pyRef));
 }
 
 nb::object PyModule::createFromCapsule(nb::object capsule) {
@@ -2084,6 +2099,8 @@ PyInsertionPoint PyInsertionPoint::after(PyOperationBase &op) {
   return PyInsertionPoint{block, std::move(nextOpRef)};
 }
 
+size_t PyMlirContext::getLiveModuleCount() { return liveModules.size(); }
+
 nb::object PyInsertionPoint::contextEnter(nb::object insertPoint) {
   return PyThreadContextEntry::pushInsertionPoint(insertPoint);
 }
@@ -2923,6 +2940,7 @@ void mlir::python::populateIRCore(nb::module_ &m) {
              PyMlirContextRef ref = PyMlirContext::forContext(self.get());
              return ref.releaseObject();
            })
+      .def("_get_live_module_count", &PyMlirContext::getLiveModuleCount)
       .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyMlirContext::getCapsule)
       .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyMlirContext::createFromCapsule)
       .def("__enter__", &PyMlirContext::contextEnter)
diff --git a/mlir/lib/Bindings/Python/IRModule.h b/mlir/lib/Bindings/Python/IRModule.h
index 1d1ff29533f98..28b885f136fe0 100644
--- a/mlir/lib/Bindings/Python/IRModule.h
+++ b/mlir/lib/Bindings/Python/IRModule.h
@@ -218,6 +218,10 @@ class PyMlirContext {
   /// Gets the count of live context objects. Used for testing.
   static size_t getLiveCount();
 
+  /// Gets the count of live modules associated with this context.
+  /// Used for testing.
+  size_t getLiveModuleCount();
+
   /// Enter and exit the context manager.
   static nanobind::object contextEnter(nanobind::object context);
   void contextExit(const nanobind::object &excType,
@@ -244,6 +248,14 @@ class PyMlirContext {
   static nanobind::ft_mutex live_contexts_mutex;
   static LiveContextMap &getLiveContexts();
 
+  // Interns all live modules associated with this context. Modules tracked
+  // in this map are valid. When a module is invalidated, it is removed
+  // from this map, and while it still exists as an instance, any
+  // attempt to access it will raise an error.
+  using LiveModuleMap =
+      llvm::DenseMap<const void *, std::pair<nanobind::handle, PyModule *>>;
+  LiveModuleMap liveModules;
+
   bool emitErrorDiagnostics = false;
 
   MlirContext context;
diff --git a/mlir/test/python/ir/module.py b/mlir/test/python/ir/module.py
index ad4c9340a6c82..175f26007ec6a 100644
--- a/mlir/test/python/ir/module.py
+++ b/mlir/test/python/ir/module.py
@@ -121,6 +121,7 @@ def testRoundtripBinary():
 def testModuleOperation():
     ctx = Context()
     module = Module.parse(r"""module @successfulParse {}""", ctx)
+    assert ctx._get_live_module_count() == 1
     op1 = module.operation
     # CHECK: module @successfulParse
     print(op1)
@@ -145,6 +146,7 @@ def testModuleOperation():
     op1 = None
     op2 = None
     gc.collect()
+    assert ctx._get_live_module_count() == 0
 
 
 # CHECK-LABEL: TEST: testModuleCapsule
@@ -152,17 +154,16 @@ def testModuleOperation():
 def testModuleCapsule():
     ctx = Context()
     module = Module.parse(r"""module @successfulParse {}""", ctx)
+    assert ctx._get_live_module_count() == 1
     # CHECK: "mlir.ir.Module._CAPIPtr"
     module_capsule = module._CAPIPtr
     print(module_capsule)
     module_dup = Module._CAPICreate(module_capsule)
-    assert module is not module_dup
-    assert module == module_dup
-    module._clear_mlir_module()
-    assert module != module_dup
+    assert module is module_dup
     assert module_dup.context is ctx
     # Gc and verify destructed.
     module = None
     module_capsule = None
     module_dup = None
     gc.collect()
+    assert ctx._get_live_module_count() == 0
 | 
liveModuleMap
      07e4202    to
    9698a5a      
    Compare
  
    | Noting that this was a fix for an issue introduced by #155114. | 
There are cases where the same module can have multiple references (via
`PyModule::forModule` via `PyModule::createFromCapsule`), and thus when `PyModule`s get gc'd, `mlirModuleDestroy` can get called multiple times for the same actual underlying `mlir::Module` (i.e., a double free). So we do actually need a "liveness map" for modules. Note, if `type_caster<MlirModule>::from_cpp` weren't a thing, we could guarantee this never happened except explicitly when users called `PyModule::createFromCapsule`.