Skip to content

Commit 6a4f664

Browse files
authored
[MLIR][Python] restore liveModuleMap (#158506)
There are cases where the same module can have multiple references (via `PyModule::forModule` via `PyModule::createFromCapsule`) and thus when `PyModule`s get gc'd `mlirModuleDestroy` can get called multiple times for the same actual underlying `mlir::Module` (i.e., double free). So we do actually need a "liveness map" for modules. Note, if `type_caster<MlirModule>::from_cpp` weren't a thing we could guarantree this never happened except explicitly when users called `PyModule::createFromCapsule`.
1 parent 65ad21d commit 6a4f664

File tree

3 files changed

+47
-15
lines changed

3 files changed

+47
-15
lines changed

mlir/lib/Bindings/Python/IRCore.cpp

Lines changed: 30 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1079,23 +1079,38 @@ PyLocation &DefaultingPyLocation::resolve() {
10791079
PyModule::PyModule(PyMlirContextRef contextRef, MlirModule module)
10801080
: BaseContextObject(std::move(contextRef)), module(module) {}
10811081

1082-
PyModule::~PyModule() { mlirModuleDestroy(module); }
1082+
PyModule::~PyModule() {
1083+
nb::gil_scoped_acquire acquire;
1084+
auto &liveModules = getContext()->liveModules;
1085+
assert(liveModules.count(module.ptr) == 1 &&
1086+
"destroying module not in live map");
1087+
liveModules.erase(module.ptr);
1088+
mlirModuleDestroy(module);
1089+
}
10831090

10841091
PyModuleRef PyModule::forModule(MlirModule module) {
10851092
MlirContext context = mlirModuleGetContext(module);
10861093
PyMlirContextRef contextRef = PyMlirContext::forContext(context);
10871094

1088-
// Create.
1089-
PyModule *unownedModule = new PyModule(std::move(contextRef), module);
1090-
// Note that the default return value policy on cast is `automatic_reference`,
1091-
// which means "does not take ownership, does not call delete/dtor".
1092-
// We use `take_ownership`, which means "Python will call the C++ destructor
1093-
// and delete operator when the Python wrapper is garbage collected", because
1094-
// MlirModule actually wraps OwningOpRef<ModuleOp> (see mlirModuleCreateParse
1095-
// etc).
1096-
nb::object pyRef = nb::cast(unownedModule, nb::rv_policy::take_ownership);
1097-
unownedModule->handle = pyRef;
1098-
return PyModuleRef(unownedModule, std::move(pyRef));
1095+
nb::gil_scoped_acquire acquire;
1096+
auto &liveModules = contextRef->liveModules;
1097+
auto it = liveModules.find(module.ptr);
1098+
if (it == liveModules.end()) {
1099+
// Create.
1100+
PyModule *unownedModule = new PyModule(std::move(contextRef), module);
1101+
// Note that the default return value policy on cast is automatic_reference,
1102+
// which does not take ownership (delete will not be called).
1103+
// Just be explicit.
1104+
nb::object pyRef = nb::cast(unownedModule, nb::rv_policy::take_ownership);
1105+
unownedModule->handle = pyRef;
1106+
liveModules[module.ptr] =
1107+
std::make_pair(unownedModule->handle, unownedModule);
1108+
return PyModuleRef(unownedModule, std::move(pyRef));
1109+
}
1110+
// Use existing.
1111+
PyModule *existing = it->second.second;
1112+
nb::object pyRef = nb::borrow<nb::object>(it->second.first);
1113+
return PyModuleRef(existing, std::move(pyRef));
10991114
}
11001115

11011116
nb::object PyModule::createFromCapsule(nb::object capsule) {
@@ -2084,6 +2099,8 @@ PyInsertionPoint PyInsertionPoint::after(PyOperationBase &op) {
20842099
return PyInsertionPoint{block, std::move(nextOpRef)};
20852100
}
20862101

2102+
size_t PyMlirContext::getLiveModuleCount() { return liveModules.size(); }
2103+
20872104
nb::object PyInsertionPoint::contextEnter(nb::object insertPoint) {
20882105
return PyThreadContextEntry::pushInsertionPoint(insertPoint);
20892106
}
@@ -2923,6 +2940,7 @@ void mlir::python::populateIRCore(nb::module_ &m) {
29232940
PyMlirContextRef ref = PyMlirContext::forContext(self.get());
29242941
return ref.releaseObject();
29252942
})
2943+
.def("_get_live_module_count", &PyMlirContext::getLiveModuleCount)
29262944
.def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyMlirContext::getCapsule)
29272945
.def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyMlirContext::createFromCapsule)
29282946
.def("__enter__", &PyMlirContext::contextEnter)

mlir/lib/Bindings/Python/IRModule.h

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -218,6 +218,10 @@ class PyMlirContext {
218218
/// Gets the count of live context objects. Used for testing.
219219
static size_t getLiveCount();
220220

221+
/// Gets the count of live modules associated with this context.
222+
/// Used for testing.
223+
size_t getLiveModuleCount();
224+
221225
/// Enter and exit the context manager.
222226
static nanobind::object contextEnter(nanobind::object context);
223227
void contextExit(const nanobind::object &excType,
@@ -244,6 +248,14 @@ class PyMlirContext {
244248
static nanobind::ft_mutex live_contexts_mutex;
245249
static LiveContextMap &getLiveContexts();
246250

251+
// Interns all live modules associated with this context. Modules tracked
252+
// in this map are valid. When a module is invalidated, it is removed
253+
// from this map, and while it still exists as an instance, any
254+
// attempt to access it will raise an error.
255+
using LiveModuleMap =
256+
llvm::DenseMap<const void *, std::pair<nanobind::handle, PyModule *>>;
257+
LiveModuleMap liveModules;
258+
247259
bool emitErrorDiagnostics = false;
248260

249261
MlirContext context;

mlir/test/python/ir/module.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -121,6 +121,7 @@ def testRoundtripBinary():
121121
def testModuleOperation():
122122
ctx = Context()
123123
module = Module.parse(r"""module @successfulParse {}""", ctx)
124+
assert ctx._get_live_module_count() == 1
124125
op1 = module.operation
125126
# CHECK: module @successfulParse
126127
print(op1)
@@ -145,24 +146,25 @@ def testModuleOperation():
145146
op1 = None
146147
op2 = None
147148
gc.collect()
149+
assert ctx._get_live_module_count() == 0
148150

149151

150152
# CHECK-LABEL: TEST: testModuleCapsule
151153
@run
152154
def testModuleCapsule():
153155
ctx = Context()
154156
module = Module.parse(r"""module @successfulParse {}""", ctx)
157+
assert ctx._get_live_module_count() == 1
155158
# CHECK: "mlir.ir.Module._CAPIPtr"
156159
module_capsule = module._CAPIPtr
157160
print(module_capsule)
158161
module_dup = Module._CAPICreate(module_capsule)
159-
assert module is not module_dup
162+
assert module is module_dup
160163
assert module == module_dup
161-
module._clear_mlir_module()
162-
assert module != module_dup
163164
assert module_dup.context is ctx
164165
# Gc and verify destructed.
165166
module = None
166167
module_capsule = None
167168
module_dup = None
168169
gc.collect()
170+
assert ctx._get_live_module_count() == 0

0 commit comments

Comments
 (0)