Skip to content

Commit 77389e4

Browse files
authored
[MLIR][Python] remove liveOperations (#155114)
Historical context: `PyMlirContext::liveOperations` was an optimization meant to cut down on the number of Python object allocations and (partially) a mechanism for updating the validity of ops after transformation, e.g. while walking/transforming the AST. See the original patch [here](https://reviews.llvm.org/D87958). This change was inspired by a [renewed](llvm/llvm-project#139721 (comment)) interest in llvm/llvm-project#139721 (which has become a little stale...). <p align="center"> <img width="504" height="375" alt="image" src="https://github.com/user-attachments/assets/0daad562-d3d1-4876-8d01-5dba382ab186" /> </p> In the previous go-around (llvm/llvm-project#92631) there were two issues, both of which have now been resolved: 1. ops that were "fetched" under a root op which has been transformed are no longer reported as invalid. We simply "[formally forbid](llvm/llvm-project#92631 (comment))" this; 2. `Module._CAPICreate(module_capsule)` must now be followed by a `module._clear_mlir_module()` to prevent double-freeing of the actual `ModuleOp` object (i.e. calling the dtor on the `OwningOpRef<ModuleOp>`): ```python module = ... module_dup = Module._CAPICreate(module._CAPIPtr) module._clear_mlir_module() ``` - **the alternative choice** here is to remove the `Module._CAPICreate` API altogether and replace it with something like `Module._move(module)` which will do both `Module._CAPICreate` and `module._clear_mlir_module`. Note, the other approach I explored last year was a [weakref system](llvm/llvm-project#97340) for `mlir::Operation` which would effectively hoist this `liveOperations` thing into MLIR core. Possibly doable but I now believe it's a bad idea. The other potentially breaking change is that `is`, which checks object identity rather than value equality, will now report `False` because we are always allocating `new` Python objects (i.e. that's the whole point of this change). Users wanting to check equality for `Operation` and `Module` should use `==`.
1 parent 03999f8 commit 77389e4

File tree

6 files changed

+61
-215
lines changed

6 files changed

+61
-215
lines changed

mlir/include/mlir-c/IR.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -415,6 +415,9 @@ MLIR_CAPI_EXPORTED MlirOperation mlirModuleGetOperation(MlirModule module);
415415
/// The returned module is null when the input operation was not a ModuleOp.
416416
MLIR_CAPI_EXPORTED MlirModule mlirModuleFromOperation(MlirOperation op);
417417

418+
/// Checks if two modules are equal.
419+
MLIR_CAPI_EXPORTED bool mlirModuleEqual(MlirModule lhs, MlirModule rhs);
420+
418421
//===----------------------------------------------------------------------===//
419422
// Operation state.
420423
//===----------------------------------------------------------------------===//

mlir/lib/Bindings/Python/IRCore.cpp

Lines changed: 46 additions & 150 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,12 @@ Returns a new MlirModule or raises an MLIRError if the parsing fails.
6767
See also: https://mlir.llvm.org/docs/LangRef/
6868
)";
6969

70+
static const char kModuleCAPICreate[] =
71+
R"(Creates a Module from a MlirModule wrapped by a capsule (i.e. module._CAPIPtr).
72+
Note this returns a new object BUT _clear_mlir_module(module) must be called to
73+
prevent double-frees (of the underlying mlir::Module).
74+
)";
75+
7076
static const char kOperationCreateDocstring[] =
7177
R"(Creates a new operation.
7278
@@ -702,84 +708,6 @@ size_t PyMlirContext::getLiveCount() {
702708
return getLiveContexts().size();
703709
}
704710

705-
size_t PyMlirContext::getLiveOperationCount() {
706-
nb::ft_lock_guard lock(liveOperationsMutex);
707-
return liveOperations.size();
708-
}
709-
710-
std::vector<PyOperation *> PyMlirContext::getLiveOperationObjects() {
711-
std::vector<PyOperation *> liveObjects;
712-
nb::ft_lock_guard lock(liveOperationsMutex);
713-
for (auto &entry : liveOperations)
714-
liveObjects.push_back(entry.second.second);
715-
return liveObjects;
716-
}
717-
718-
size_t PyMlirContext::clearLiveOperations() {
719-
720-
LiveOperationMap operations;
721-
{
722-
nb::ft_lock_guard lock(liveOperationsMutex);
723-
std::swap(operations, liveOperations);
724-
}
725-
for (auto &op : operations)
726-
op.second.second->setInvalid();
727-
size_t numInvalidated = operations.size();
728-
return numInvalidated;
729-
}
730-
731-
void PyMlirContext::clearOperation(MlirOperation op) {
732-
PyOperation *pyOp;
733-
{
734-
nb::ft_lock_guard lock(liveOperationsMutex);
735-
auto it = liveOperations.find(op.ptr);
736-
if (it == liveOperations.end()) {
737-
return;
738-
}
739-
pyOp = it->second.second;
740-
liveOperations.erase(it);
741-
}
742-
pyOp->setInvalid();
743-
}
744-
745-
void PyMlirContext::clearOperationsInside(PyOperationBase &op) {
746-
using callBackData = struct {
747-
PyOperation &rootOp;
748-
bool rootSeen;
749-
};
750-
callBackData data{op.getOperation(), false};
751-
// Mark all ops below the op that the passmanager will be rooted
752-
// at (but not op itself - note the preorder) as invalid.
753-
MlirOperationWalkCallback invalidatingCallback = [](MlirOperation op,
754-
void *userData) {
755-
callBackData *data = static_cast<callBackData *>(userData);
756-
if (LLVM_LIKELY(data->rootSeen))
757-
data->rootOp.getOperation().getContext()->clearOperation(op);
758-
else
759-
data->rootSeen = true;
760-
return MlirWalkResult::MlirWalkResultAdvance;
761-
};
762-
mlirOperationWalk(op.getOperation(), invalidatingCallback,
763-
static_cast<void *>(&data), MlirWalkPreOrder);
764-
}
765-
void PyMlirContext::clearOperationsInside(MlirOperation op) {
766-
PyOperationRef opRef = PyOperation::forOperation(getRef(), op);
767-
clearOperationsInside(opRef->getOperation());
768-
}
769-
770-
void PyMlirContext::clearOperationAndInside(PyOperationBase &op) {
771-
MlirOperationWalkCallback invalidatingCallback = [](MlirOperation op,
772-
void *userData) {
773-
PyMlirContextRef &contextRef = *static_cast<PyMlirContextRef *>(userData);
774-
contextRef->clearOperation(op);
775-
return MlirWalkResult::MlirWalkResultAdvance;
776-
};
777-
mlirOperationWalk(op.getOperation(), invalidatingCallback,
778-
&op.getOperation().getContext(), MlirWalkPreOrder);
779-
}
780-
781-
size_t PyMlirContext::getLiveModuleCount() { return liveModules.size(); }
782-
783711
nb::object PyMlirContext::contextEnter(nb::object context) {
784712
return PyThreadContextEntry::pushContext(context);
785713
}
@@ -1151,38 +1079,23 @@ PyLocation &DefaultingPyLocation::resolve() {
11511079
PyModule::PyModule(PyMlirContextRef contextRef, MlirModule module)
11521080
: BaseContextObject(std::move(contextRef)), module(module) {}
11531081

1154-
PyModule::~PyModule() {
1155-
nb::gil_scoped_acquire acquire;
1156-
auto &liveModules = getContext()->liveModules;
1157-
assert(liveModules.count(module.ptr) == 1 &&
1158-
"destroying module not in live map");
1159-
liveModules.erase(module.ptr);
1160-
mlirModuleDestroy(module);
1161-
}
1082+
PyModule::~PyModule() { mlirModuleDestroy(module); }
11621083

11631084
PyModuleRef PyModule::forModule(MlirModule module) {
11641085
MlirContext context = mlirModuleGetContext(module);
11651086
PyMlirContextRef contextRef = PyMlirContext::forContext(context);
11661087

1167-
nb::gil_scoped_acquire acquire;
1168-
auto &liveModules = contextRef->liveModules;
1169-
auto it = liveModules.find(module.ptr);
1170-
if (it == liveModules.end()) {
1171-
// Create.
1172-
PyModule *unownedModule = new PyModule(std::move(contextRef), module);
1173-
// Note that the default return value policy on cast is automatic_reference,
1174-
// which does not take ownership (delete will not be called).
1175-
// Just be explicit.
1176-
nb::object pyRef = nb::cast(unownedModule, nb::rv_policy::take_ownership);
1177-
unownedModule->handle = pyRef;
1178-
liveModules[module.ptr] =
1179-
std::make_pair(unownedModule->handle, unownedModule);
1180-
return PyModuleRef(unownedModule, std::move(pyRef));
1181-
}
1182-
// Use existing.
1183-
PyModule *existing = it->second.second;
1184-
nb::object pyRef = nb::borrow<nb::object>(it->second.first);
1185-
return PyModuleRef(existing, std::move(pyRef));
1088+
// Create.
1089+
PyModule *unownedModule = new PyModule(std::move(contextRef), module);
1090+
// Note that the default return value policy on cast is `automatic_reference`,
1091+
// which means "does not take ownership, does not call delete/dtor".
1092+
// We use `take_ownership`, which means "Python will call the C++ destructor
1093+
// and delete operator when the Python wrapper is garbage collected", because
1094+
// MlirModule actually wraps OwningOpRef<ModuleOp> (see mlirModuleCreateParse
1095+
// etc).
1096+
nb::object pyRef = nb::cast(unownedModule, nb::rv_policy::take_ownership);
1097+
unownedModule->handle = pyRef;
1098+
return PyModuleRef(unownedModule, std::move(pyRef));
11861099
}
11871100

11881101
nb::object PyModule::createFromCapsule(nb::object capsule) {
@@ -1207,15 +1120,11 @@ PyOperation::~PyOperation() {
12071120
// If the operation has already been invalidated there is nothing to do.
12081121
if (!valid)
12091122
return;
1210-
1211-
// Otherwise, invalidate the operation and remove it from live map when it is
1212-
// attached.
1213-
if (isAttached()) {
1214-
getContext()->clearOperation(*this);
1215-
} else {
1216-
// And destroy it when it is detached, i.e. owned by Python, in which case
1217-
// all nested operations must be invalidated and removed from the live map as
1218-
// well.
1123+
// Otherwise, invalidate the operation when it is attached.
1124+
if (isAttached())
1125+
setInvalid();
1126+
else {
1127+
// And destroy it when it is detached, i.e. owned by Python.
12191128
erase();
12201129
}
12211130
}
@@ -1252,35 +1161,15 @@ PyOperationRef PyOperation::createInstance(PyMlirContextRef contextRef,
12521161
PyOperationRef PyOperation::forOperation(PyMlirContextRef contextRef,
12531162
MlirOperation operation,
12541163
nb::object parentKeepAlive) {
1255-
nb::ft_lock_guard lock(contextRef->liveOperationsMutex);
1256-
auto &liveOperations = contextRef->liveOperations;
1257-
auto it = liveOperations.find(operation.ptr);
1258-
if (it == liveOperations.end()) {
1259-
// Create.
1260-
PyOperationRef result = createInstance(std::move(contextRef), operation,
1261-
std::move(parentKeepAlive));
1262-
liveOperations[operation.ptr] =
1263-
std::make_pair(result.getObject(), result.get());
1264-
return result;
1265-
}
1266-
// Use existing.
1267-
PyOperation *existing = it->second.second;
1268-
nb::object pyRef = nb::borrow<nb::object>(it->second.first);
1269-
return PyOperationRef(existing, std::move(pyRef));
1164+
return createInstance(std::move(contextRef), operation,
1165+
std::move(parentKeepAlive));
12701166
}
12711167

12721168
PyOperationRef PyOperation::createDetached(PyMlirContextRef contextRef,
12731169
MlirOperation operation,
12741170
nb::object parentKeepAlive) {
1275-
nb::ft_lock_guard lock(contextRef->liveOperationsMutex);
1276-
auto &liveOperations = contextRef->liveOperations;
1277-
assert(liveOperations.count(operation.ptr) == 0 &&
1278-
"cannot create detached operation that already exists");
1279-
(void)liveOperations;
12801171
PyOperationRef created = createInstance(std::move(contextRef), operation,
12811172
std::move(parentKeepAlive));
1282-
liveOperations[operation.ptr] =
1283-
std::make_pair(created.getObject(), created.get());
12841173
created->attached = false;
12851174
return created;
12861175
}
@@ -1652,7 +1541,7 @@ nb::object PyOperation::createOpView() {
16521541

16531542
void PyOperation::erase() {
16541543
checkValid();
1655-
getContext()->clearOperationAndInside(*this);
1544+
setInvalid();
16561545
mlirOperationDestroy(operation);
16571546
}
16581547

@@ -3023,14 +2912,6 @@ void mlir::python::populateIRCore(nb::module_ &m) {
30232912
PyMlirContextRef ref = PyMlirContext::forContext(self.get());
30242913
return ref.releaseObject();
30252914
})
3026-
.def("_get_live_operation_count", &PyMlirContext::getLiveOperationCount)
3027-
.def("_get_live_operation_objects",
3028-
&PyMlirContext::getLiveOperationObjects)
3029-
.def("_clear_live_operations", &PyMlirContext::clearLiveOperations)
3030-
.def("_clear_live_operations_inside",
3031-
nb::overload_cast<MlirOperation>(
3032-
&PyMlirContext::clearOperationsInside))
3033-
.def("_get_live_module_count", &PyMlirContext::getLiveModuleCount)
30342915
.def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyMlirContext::getCapsule)
30352916
.def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyMlirContext::createFromCapsule)
30362917
.def("__enter__", &PyMlirContext::contextEnter)
@@ -3348,7 +3229,9 @@ void mlir::python::populateIRCore(nb::module_ &m) {
33483229
//----------------------------------------------------------------------------
33493230
nb::class_<PyModule>(m, "Module", nb::is_weak_referenceable())
33503231
.def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyModule::getCapsule)
3351-
.def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyModule::createFromCapsule)
3232+
.def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyModule::createFromCapsule,
3233+
kModuleCAPICreate)
3234+
.def("_clear_mlir_module", &PyModule::clearMlirModule)
33523235
.def_static(
33533236
"parse",
33543237
[](const std::string &moduleAsm, DefaultingPyMlirContext context) {
@@ -3428,7 +3311,13 @@ void mlir::python::populateIRCore(nb::module_ &m) {
34283311
// Defer to the operation's __str__.
34293312
return self.attr("operation").attr("__str__")();
34303313
},
3431-
kOperationStrDunderDocstring);
3314+
kOperationStrDunderDocstring)
3315+
.def(
3316+
"__eq__",
3317+
[](PyModule &self, PyModule &other) {
3318+
return mlirModuleEqual(self.get(), other.get());
3319+
},
3320+
"other"_a);
34323321

34333322
//----------------------------------------------------------------------------
34343323
// Mapping of Operation.
@@ -3440,7 +3329,8 @@ void mlir::python::populateIRCore(nb::module_ &m) {
34403329
})
34413330
.def("__eq__",
34423331
[](PyOperationBase &self, PyOperationBase &other) {
3443-
return &self.getOperation() == &other.getOperation();
3332+
return mlirOperationEqual(self.getOperation().get(),
3333+
other.getOperation().get());
34443334
})
34453335
.def("__eq__",
34463336
[](PyOperationBase &self, nb::object other) { return false; })
@@ -3655,7 +3545,9 @@ void mlir::python::populateIRCore(nb::module_ &m) {
36553545
[](PyOperationBase &self) {
36563546
return PyOpSuccessors(self.getOperation().getRef());
36573547
},
3658-
"Returns the list of Operation successors.");
3548+
"Returns the list of Operation successors.")
3549+
.def("_set_invalid", &PyOperation::setInvalid,
3550+
"Invalidate the operation.");
36593551

36603552
auto opViewClass =
36613553
nb::class_<PyOpView, PyOperationBase>(m, "OpView")
@@ -3699,7 +3591,11 @@ void mlir::python::populateIRCore(nb::module_ &m) {
36993591
[](PyOperationBase &self) {
37003592
return PyOpSuccessors(self.getOperation().getRef());
37013593
},
3702-
"Returns the list of Operation successors.");
3594+
"Returns the list of Operation successors.")
3595+
.def(
3596+
"_set_invalid",
3597+
[](PyOpView &self) { self.getOperation().setInvalid(); },
3598+
"Invalidate the operation.");
37033599
opViewClass.attr("_ODS_REGIONS") = nb::make_tuple(0, true);
37043600
opViewClass.attr("_ODS_OPERAND_SEGMENTS") = nb::none();
37053601
opViewClass.attr("_ODS_RESULT_SEGMENTS") = nb::none();

mlir/lib/Bindings/Python/IRModule.h

Lines changed: 6 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -218,40 +218,6 @@ class PyMlirContext {
218218
/// Gets the count of live context objects. Used for testing.
219219
static size_t getLiveCount();
220220

221-
/// Get a list of Python objects which are still in the live context map.
222-
std::vector<PyOperation *> getLiveOperationObjects();
223-
224-
/// Gets the count of live operations associated with this context.
225-
/// Used for testing.
226-
size_t getLiveOperationCount();
227-
228-
/// Clears the live operations map, returning the number of entries which were
229-
/// invalidated. To be used as a safety mechanism so that API end-users can't
230-
/// corrupt by holding references they shouldn't have accessed in the first
231-
/// place.
232-
size_t clearLiveOperations();
233-
234-
/// Removes an operation from the live operations map and sets it invalid.
235-
/// This is useful for when some non-bindings code destroys the operation and
236-
/// the bindings need to be made aware. For example, in the case when pass
237-
/// manager is run.
238-
///
239-
/// Note that this does *NOT* clear the nested operations.
240-
void clearOperation(MlirOperation op);
241-
242-
/// Clears all operations nested inside the given op using
243-
/// `clearOperation(MlirOperation)`.
244-
void clearOperationsInside(PyOperationBase &op);
245-
void clearOperationsInside(MlirOperation op);
246-
247-
/// Clears the operation _and_ all operations inside using
248-
/// `clearOperation(MlirOperation)`.
249-
void clearOperationAndInside(PyOperationBase &op);
250-
251-
/// Gets the count of live modules associated with this context.
252-
/// Used for testing.
253-
size_t getLiveModuleCount();
254-
255221
/// Enter and exit the context manager.
256222
static nanobind::object contextEnter(nanobind::object context);
257223
void contextExit(const nanobind::object &excType,
@@ -278,25 +244,6 @@ class PyMlirContext {
278244
static nanobind::ft_mutex live_contexts_mutex;
279245
static LiveContextMap &getLiveContexts();
280246

281-
// Interns all live modules associated with this context. Modules tracked
282-
// in this map are valid. When a module is invalidated, it is removed
283-
// from this map, and while it still exists as an instance, any
284-
// attempt to access it will raise an error.
285-
using LiveModuleMap =
286-
llvm::DenseMap<const void *, std::pair<nanobind::handle, PyModule *>>;
287-
LiveModuleMap liveModules;
288-
289-
// Interns all live operations associated with this context. Operations
290-
// tracked in this map are valid. When an operation is invalidated, it is
291-
// removed from this map, and while it still exists as an instance, any
292-
// attempt to access it will raise an error.
293-
using LiveOperationMap =
294-
llvm::DenseMap<void *, std::pair<nanobind::handle, PyOperation *>>;
295-
nanobind::ft_mutex liveOperationsMutex;
296-
297-
// Guarded by liveOperationsMutex in free-threading mode.
298-
LiveOperationMap liveOperations;
299-
300247
bool emitErrorDiagnostics = false;
301248

302249
MlirContext context;
@@ -548,8 +495,8 @@ class PyModule;
548495
using PyModuleRef = PyObjectRef<PyModule>;
549496
class PyModule : public BaseContextObject {
550497
public:
551-
/// Returns a PyModule reference for the given MlirModule. This may return
552-
/// a pre-existing or new object.
498+
/// Returns a PyModule reference for the given MlirModule. This always returns
499+
/// a new object.
553500
static PyModuleRef forModule(MlirModule module);
554501
PyModule(PyModule &) = delete;
555502
PyModule(PyMlirContext &&) = delete;
@@ -570,11 +517,12 @@ class PyModule : public BaseContextObject {
570517
nanobind::object getCapsule();
571518

572519
/// Creates a PyModule from the MlirModule wrapped by a capsule.
573-
/// Note that PyModule instances are uniqued, so the returned object
574-
/// may be a pre-existing object. Ownership of the underlying MlirModule
575-
/// is taken by calling this function.
520+
/// Note this returns a new object BUT clearMlirModule() must be called to
521+
/// prevent double-frees (of the underlying mlir::ModuleOp).
576522
static nanobind::object createFromCapsule(nanobind::object capsule);
577523

524+
void clearMlirModule() { module = {nullptr}; }
525+
578526
private:
579527
PyModule(PyMlirContextRef contextRef, MlirModule module);
580528
MlirModule module;

0 commit comments

Comments
 (0)