Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 30 additions & 12 deletions mlir/lib/Bindings/Python/IRCore.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1079,23 +1079,38 @@ PyLocation &DefaultingPyLocation::resolve() {
PyModule::PyModule(PyMlirContextRef contextRef, MlirModule module)
: BaseContextObject(std::move(contextRef)), module(module) {}

PyModule::~PyModule() { mlirModuleDestroy(module); }
PyModule::~PyModule() {
nb::gil_scoped_acquire acquire;
auto &liveModules = getContext()->liveModules;
assert(liveModules.count(module.ptr) == 1 &&
"destroying module not in live map");
liveModules.erase(module.ptr);
mlirModuleDestroy(module);
}

PyModuleRef PyModule::forModule(MlirModule module) {
MlirContext context = mlirModuleGetContext(module);
PyMlirContextRef contextRef = PyMlirContext::forContext(context);

// Create.
PyModule *unownedModule = new PyModule(std::move(contextRef), module);
// Note that the default return value policy on cast is `automatic_reference`,
// which means "does not take ownership, does not call delete/dtor".
// We use `take_ownership`, which means "Python will call the C++ destructor
// and delete operator when the Python wrapper is garbage collected", because
// MlirModule actually wraps OwningOpRef<ModuleOp> (see mlirModuleCreateParse
// etc).
nb::object pyRef = nb::cast(unownedModule, nb::rv_policy::take_ownership);
unownedModule->handle = pyRef;
return PyModuleRef(unownedModule, std::move(pyRef));
nb::gil_scoped_acquire acquire;
auto &liveModules = contextRef->liveModules;
auto it = liveModules.find(module.ptr);
if (it == liveModules.end()) {
// Create.
PyModule *unownedModule = new PyModule(std::move(contextRef), module);
// Note that the default return value policy on cast is automatic_reference,
// which does not take ownership (delete will not be called).
// Just be explicit.
nb::object pyRef = nb::cast(unownedModule, nb::rv_policy::take_ownership);
unownedModule->handle = pyRef;
liveModules[module.ptr] =
std::make_pair(unownedModule->handle, unownedModule);
return PyModuleRef(unownedModule, std::move(pyRef));
}
// Use existing.
PyModule *existing = it->second.second;
nb::object pyRef = nb::borrow<nb::object>(it->second.first);
return PyModuleRef(existing, std::move(pyRef));
}

nb::object PyModule::createFromCapsule(nb::object capsule) {
Expand Down Expand Up @@ -2084,6 +2099,8 @@ PyInsertionPoint PyInsertionPoint::after(PyOperationBase &op) {
return PyInsertionPoint{block, std::move(nextOpRef)};
}

size_t PyMlirContext::getLiveModuleCount() { return liveModules.size(); }

nb::object PyInsertionPoint::contextEnter(nb::object insertPoint) {
return PyThreadContextEntry::pushInsertionPoint(insertPoint);
}
Expand Down Expand Up @@ -2923,6 +2940,7 @@ void mlir::python::populateIRCore(nb::module_ &m) {
PyMlirContextRef ref = PyMlirContext::forContext(self.get());
return ref.releaseObject();
})
.def("_get_live_module_count", &PyMlirContext::getLiveModuleCount)
.def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyMlirContext::getCapsule)
.def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyMlirContext::createFromCapsule)
.def("__enter__", &PyMlirContext::contextEnter)
Expand Down
12 changes: 12 additions & 0 deletions mlir/lib/Bindings/Python/IRModule.h
Original file line number Diff line number Diff line change
Expand Up @@ -218,6 +218,10 @@ class PyMlirContext {
/// Gets the count of live context objects. Used for testing.
static size_t getLiveCount();

/// Gets the count of live modules associated with this context.
/// Used for testing.
size_t getLiveModuleCount();

/// Enter and exit the context manager.
static nanobind::object contextEnter(nanobind::object context);
void contextExit(const nanobind::object &excType,
Expand All @@ -244,6 +248,14 @@ class PyMlirContext {
static nanobind::ft_mutex live_contexts_mutex;
static LiveContextMap &getLiveContexts();

// Interns all live modules associated with this context. Modules tracked
// in this map are valid. When a module is invalidated, it is removed
// from this map, and while it still exists as an instance, any
// attempt to access it will raise an error.
using LiveModuleMap =
llvm::DenseMap<const void *, std::pair<nanobind::handle, PyModule *>>;
LiveModuleMap liveModules;

bool emitErrorDiagnostics = false;

MlirContext context;
Expand Down
8 changes: 5 additions & 3 deletions mlir/test/python/ir/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,7 @@ def testRoundtripBinary():
def testModuleOperation():
ctx = Context()
module = Module.parse(r"""module @successfulParse {}""", ctx)
assert ctx._get_live_module_count() == 1
op1 = module.operation
# CHECK: module @successfulParse
print(op1)
Expand All @@ -145,24 +146,25 @@ def testModuleOperation():
op1 = None
op2 = None
gc.collect()
assert ctx._get_live_module_count() == 0


# CHECK-LABEL: TEST: testModuleCapsule
@run
def testModuleCapsule():
ctx = Context()
module = Module.parse(r"""module @successfulParse {}""", ctx)
assert ctx._get_live_module_count() == 1
# CHECK: "mlir.ir.Module._CAPIPtr"
module_capsule = module._CAPIPtr
print(module_capsule)
module_dup = Module._CAPICreate(module_capsule)
assert module is not module_dup
assert module is module_dup
assert module == module_dup
module._clear_mlir_module()
assert module != module_dup
assert module_dup.context is ctx
# Gc and verify destructed.
module = None
module_capsule = None
module_dup = None
gc.collect()
assert ctx._get_live_module_count() == 0
Loading