@@ -1079,23 +1079,38 @@ PyLocation &DefaultingPyLocation::resolve() {
1079
1079
PyModule::PyModule (PyMlirContextRef contextRef, MlirModule module )
1080
1080
: BaseContextObject(std::move(contextRef)), module (module ) {}
1081
1081
1082
- PyModule::~PyModule () { mlirModuleDestroy (module ); }
1082
+ PyModule::~PyModule () {
1083
+ nb::gil_scoped_acquire acquire;
1084
+ auto &liveModules = getContext ()->liveModules ;
1085
+ assert (liveModules.count (module .ptr ) == 1 &&
1086
+ " destroying module not in live map" );
1087
+ liveModules.erase (module .ptr );
1088
+ mlirModuleDestroy (module );
1089
+ }
1083
1090
1084
1091
PyModuleRef PyModule::forModule (MlirModule module ) {
1085
1092
MlirContext context = mlirModuleGetContext (module );
1086
1093
PyMlirContextRef contextRef = PyMlirContext::forContext (context);
1087
1094
1088
- // Create.
1089
- PyModule *unownedModule = new PyModule (std::move (contextRef), module );
1090
- // Note that the default return value policy on cast is `automatic_reference`,
1091
- // which means "does not take ownership, does not call delete/dtor".
1092
- // We use `take_ownership`, which means "Python will call the C++ destructor
1093
- // and delete operator when the Python wrapper is garbage collected", because
1094
- // MlirModule actually wraps OwningOpRef<ModuleOp> (see mlirModuleCreateParse
1095
- // etc).
1096
- nb::object pyRef = nb::cast (unownedModule, nb::rv_policy::take_ownership);
1097
- unownedModule->handle = pyRef;
1098
- return PyModuleRef (unownedModule, std::move (pyRef));
1095
+ nb::gil_scoped_acquire acquire;
1096
+ auto &liveModules = contextRef->liveModules ;
1097
+ auto it = liveModules.find (module .ptr );
1098
+ if (it == liveModules.end ()) {
1099
+ // Create.
1100
+ PyModule *unownedModule = new PyModule (std::move (contextRef), module );
1101
+ // Note that the default return value policy on cast is automatic_reference,
1102
+ // which does not take ownership (delete will not be called).
1103
+ // Just be explicit.
1104
+ nb::object pyRef = nb::cast (unownedModule, nb::rv_policy::take_ownership);
1105
+ unownedModule->handle = pyRef;
1106
+ liveModules[module .ptr ] =
1107
+ std::make_pair (unownedModule->handle , unownedModule);
1108
+ return PyModuleRef (unownedModule, std::move (pyRef));
1109
+ }
1110
+ // Use existing.
1111
+ PyModule *existing = it->second .second ;
1112
+ nb::object pyRef = nb::borrow<nb::object>(it->second .first );
1113
+ return PyModuleRef (existing, std::move (pyRef));
1099
1114
}
1100
1115
1101
1116
nb::object PyModule::createFromCapsule (nb::object capsule) {
@@ -2084,6 +2099,8 @@ PyInsertionPoint PyInsertionPoint::after(PyOperationBase &op) {
2084
2099
return PyInsertionPoint{block, std::move (nextOpRef)};
2085
2100
}
2086
2101
2102
+ size_t PyMlirContext::getLiveModuleCount () { return liveModules.size (); }
2103
+
2087
2104
nb::object PyInsertionPoint::contextEnter (nb::object insertPoint) {
2088
2105
return PyThreadContextEntry::pushInsertionPoint (insertPoint);
2089
2106
}
@@ -2923,6 +2940,7 @@ void mlir::python::populateIRCore(nb::module_ &m) {
2923
2940
PyMlirContextRef ref = PyMlirContext::forContext (self.get ());
2924
2941
return ref.releaseObject ();
2925
2942
})
2943
+ .def (" _get_live_module_count" , &PyMlirContext::getLiveModuleCount)
2926
2944
.def_prop_ro (MLIR_PYTHON_CAPI_PTR_ATTR, &PyMlirContext::getCapsule)
2927
2945
.def (MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyMlirContext::createFromCapsule)
2928
2946
.def (" __enter__" , &PyMlirContext::contextEnter)
0 commit comments