Skip to content

Commit 4d470c4

Browse files
maksleventalgithub-actions[bot]
authored andcommitted
Automerge: [MLIR][Python] remove liveOperations (#155114)
Historical context: `PyMlirContext::liveOperations` was an optimization meant to cut down on the number of Python object allocations and (partially) a mechanism for updating validity of ops after transformation. E.g. during walking/transforming the AST. See original patch [here](https://reviews.llvm.org/D87958). Inspired by a [renewed](llvm/llvm-project#139721 (comment)) interest in llvm/llvm-project#139721 (which has become a little stale...) <p align="center"> <img width="504" height="375" alt="image" src="https://github.com/user-attachments/assets/0daad562-d3d1-4876-8d01-5dba382ab186" /> </p> In the previous go-around (llvm/llvm-project#92631) there were two issues which have been resolved 1. ops that were "fetched" under a root op which has been transformed are no longer reported as invalid. We simply "[formally forbid](llvm/llvm-project#92631 (comment))" this; 2. `Module._CAPICreate(module_capsule)` must now be followed by a `module._clear_mlir_module()` to prevent double-freeing of the actual `ModuleOp` object (i.e. calling the dtor on the `OwningOpRef<ModuleOp>`): ```python module = ... module_dup = Module._CAPICreate(module._CAPIPtr) module._clear_mlir_module() ``` - **the alternative choice** here is to remove the `Module._CAPICreate` API altogether and replace it with something like `Module._move(module)` which will do both `Module._CAPICreate` and `module._clear_mlir_module`. Note, the other approach I explored last year was a [weakref system](llvm/llvm-project#97340) for `mlir::Operation` which would effectively hoist this `liveOperations` thing into MLIR core. Possibly doable but I now believe it's a bad idea. The other potentially breaking change is `is`, which checks object equality rather than value equality, will now report `False` because we are always allocating `new` Python objects (ie that's the whole point of this change). Users wanting to check equality for `Operation` and `Module` should use `==`.
2 parents 655407b + b2a7369 commit 4d470c4

File tree

11 files changed

+102
-276
lines changed

11 files changed

+102
-276
lines changed

mlir/docs/Bindings/Python.md

Lines changed: 23 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -216,13 +216,28 @@ added to an attached operation, they need to be re-parented to the containing
216216
module).
217217

218218
Due to the validity and parenting accounting needs, `PyOperation` is the owner
219-
for regions and blocks and needs to be a top-level type that we can count on not
220-
aliasing. This let's us do things like selectively invalidating instances when
221-
mutations occur without worrying that there is some alias to the same operation
222-
in the hierarchy. Operations are also the only entity that are allowed to be in
223-
a detached state, and they are interned at the context level so that there is
224-
never more than one Python `mlir.ir.Operation` object for a unique
225-
`MlirOperation`, regardless of how it is obtained.
219+
for regions and blocks. Operations are also the only entities which are allowed to be in
220+
a detached state.
221+
222+
**Note**: Multiple `PyOperation` objects (i.e., the Python objects themselves) can alias a single `mlir::Operation`.
223+
This means, for example, if you have `py_op1` and `py_op2` which wrap the same `mlir::Operation op`
224+
and you somehow transform `op` (e.g., you run a pass on `op`) then walking the MLIR AST via either/or `py_op1`, `py_op2`
225+
will reflect the same MLIR AST. This is perfectly safe and supported. What is not supported is invalidating any
226+
operation while there exist multiple Python objects wrapping that operation **and then manipulating those wrappers**.
227+
For example if `py_op1` and `py_op2` wrap the same operation under a root `py_op3` and then `py_op3` is
228+
transformed such that the operation referenced (by `py_op1`, `py_op2`) is erased. Then `py_op1`, `py_op2`
229+
become "undefined" in a sense; manipulating them in any way is "formally forbidden". Note, this also applies to
230+
`SymbolTable` mutation, which is considered a transformation of the root `SymbolTable`-supporting operation for the
231+
purposes of the discussion here. Metaphorically, one can think of this similarly to how STL container iterators are invalidated once the container itself is changed. The "best practices" recommendation is to structure your code such that
232+
233+
1. First, query/manipulate various Python wrapper objects `py_op1`, `py_op2`, `py_op3`, etc.;
234+
2. Second, transform the AST/erase operations/etc. via a single root object;
235+
3. Invalidate all queried nodes (e.g., using `op._set_invalid()`).
236+
237+
Ideally this should be done in a function body so that step (3) corresponds to the end of the function and there are no
238+
risks of Python wrapper objects leaking/living longer than necessary. In summary, you should scope your changes based on
239+
nesting i.e., change leaf nodes first before going up in hierarchy, and only in very rare cases query nested ops post
240+
modifying a parent op.
226241

227242
The C/C++ API allows for Region/Block to also be detached, but it simplifies the
228243
ownership model a lot to eliminate that possibility in this API, allowing the
@@ -238,11 +253,6 @@ blocks. We may end up needing an op-local one at some point TBD, depending on
238253
how hard it is to guarantee how mutations interact with their Python peer
239254
objects. We can cross that bridge easily when we get there.
240255

241-
Module, when used purely from the Python API, can't alias anyway, so we can use
242-
it as a top-level ref type without a live-list for interning. If the API ever
243-
changes such that this cannot be guaranteed (i.e. by letting you marshal a
244-
native-defined Module in), then there would need to be a live table for it too.
245-
246256
## User-level API
247257

248258
### Context Management
@@ -1229,4 +1239,4 @@ The exceptions to the free-threading compatibility:
12291239
- Usage of `Location.emit_error` is unsafe (due to thread-unsafe `llvm::raw_ostream`).
12301240
- Usage of `Module.dump` is unsafe (due to thread-unsafe `llvm::raw_ostream`).
12311241
- Usage of `mlir.dialects.transform.interpreter` is unsafe.
1232-
- Usage of `mlir.dialects.gpu` and `gpu-module-to-binary` is unsafe.
1242+
- Usage of `mlir.dialects.gpu` and `gpu-module-to-binary` is unsafe.

mlir/include/mlir-c/IR.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -415,6 +415,9 @@ MLIR_CAPI_EXPORTED MlirOperation mlirModuleGetOperation(MlirModule module);
415415
/// The returned module is null when the input operation was not a ModuleOp.
416416
MLIR_CAPI_EXPORTED MlirModule mlirModuleFromOperation(MlirOperation op);
417417

418+
/// Checks if two modules are equal.
419+
MLIR_CAPI_EXPORTED bool mlirModuleEqual(MlirModule lhs, MlirModule rhs);
420+
418421
//===----------------------------------------------------------------------===//
419422
// Operation state.
420423
//===----------------------------------------------------------------------===//

mlir/lib/Bindings/Python/IRCore.cpp

Lines changed: 46 additions & 150 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,12 @@ Returns a new MlirModule or raises an MLIRError if the parsing fails.
6767
See also: https://mlir.llvm.org/docs/LangRef/
6868
)";
6969

70+
static const char kModuleCAPICreate[] =
71+
R"(Creates a Module from a MlirModule wrapped by a capsule (i.e. module._CAPIPtr).
72+
Note this returns a new object BUT _clear_mlir_module(module) must be called to
73+
prevent double-frees (of the underlying mlir::Module).
74+
)";
75+
7076
static const char kOperationCreateDocstring[] =
7177
R"(Creates a new operation.
7278
@@ -702,84 +708,6 @@ size_t PyMlirContext::getLiveCount() {
702708
return getLiveContexts().size();
703709
}
704710

705-
size_t PyMlirContext::getLiveOperationCount() {
706-
nb::ft_lock_guard lock(liveOperationsMutex);
707-
return liveOperations.size();
708-
}
709-
710-
std::vector<PyOperation *> PyMlirContext::getLiveOperationObjects() {
711-
std::vector<PyOperation *> liveObjects;
712-
nb::ft_lock_guard lock(liveOperationsMutex);
713-
for (auto &entry : liveOperations)
714-
liveObjects.push_back(entry.second.second);
715-
return liveObjects;
716-
}
717-
718-
size_t PyMlirContext::clearLiveOperations() {
719-
720-
LiveOperationMap operations;
721-
{
722-
nb::ft_lock_guard lock(liveOperationsMutex);
723-
std::swap(operations, liveOperations);
724-
}
725-
for (auto &op : operations)
726-
op.second.second->setInvalid();
727-
size_t numInvalidated = operations.size();
728-
return numInvalidated;
729-
}
730-
731-
void PyMlirContext::clearOperation(MlirOperation op) {
732-
PyOperation *pyOp;
733-
{
734-
nb::ft_lock_guard lock(liveOperationsMutex);
735-
auto it = liveOperations.find(op.ptr);
736-
if (it == liveOperations.end()) {
737-
return;
738-
}
739-
pyOp = it->second.second;
740-
liveOperations.erase(it);
741-
}
742-
pyOp->setInvalid();
743-
}
744-
745-
void PyMlirContext::clearOperationsInside(PyOperationBase &op) {
746-
using callBackData = struct {
747-
PyOperation &rootOp;
748-
bool rootSeen;
749-
};
750-
callBackData data{op.getOperation(), false};
751-
// Mark all ops below the op that the passmanager will be rooted
752-
// at (but not op itself - note the preorder) as invalid.
753-
MlirOperationWalkCallback invalidatingCallback = [](MlirOperation op,
754-
void *userData) {
755-
callBackData *data = static_cast<callBackData *>(userData);
756-
if (LLVM_LIKELY(data->rootSeen))
757-
data->rootOp.getOperation().getContext()->clearOperation(op);
758-
else
759-
data->rootSeen = true;
760-
return MlirWalkResult::MlirWalkResultAdvance;
761-
};
762-
mlirOperationWalk(op.getOperation(), invalidatingCallback,
763-
static_cast<void *>(&data), MlirWalkPreOrder);
764-
}
765-
void PyMlirContext::clearOperationsInside(MlirOperation op) {
766-
PyOperationRef opRef = PyOperation::forOperation(getRef(), op);
767-
clearOperationsInside(opRef->getOperation());
768-
}
769-
770-
void PyMlirContext::clearOperationAndInside(PyOperationBase &op) {
771-
MlirOperationWalkCallback invalidatingCallback = [](MlirOperation op,
772-
void *userData) {
773-
PyMlirContextRef &contextRef = *static_cast<PyMlirContextRef *>(userData);
774-
contextRef->clearOperation(op);
775-
return MlirWalkResult::MlirWalkResultAdvance;
776-
};
777-
mlirOperationWalk(op.getOperation(), invalidatingCallback,
778-
&op.getOperation().getContext(), MlirWalkPreOrder);
779-
}
780-
781-
size_t PyMlirContext::getLiveModuleCount() { return liveModules.size(); }
782-
783711
nb::object PyMlirContext::contextEnter(nb::object context) {
784712
return PyThreadContextEntry::pushContext(context);
785713
}
@@ -1151,38 +1079,23 @@ PyLocation &DefaultingPyLocation::resolve() {
11511079
PyModule::PyModule(PyMlirContextRef contextRef, MlirModule module)
11521080
: BaseContextObject(std::move(contextRef)), module(module) {}
11531081

1154-
PyModule::~PyModule() {
1155-
nb::gil_scoped_acquire acquire;
1156-
auto &liveModules = getContext()->liveModules;
1157-
assert(liveModules.count(module.ptr) == 1 &&
1158-
"destroying module not in live map");
1159-
liveModules.erase(module.ptr);
1160-
mlirModuleDestroy(module);
1161-
}
1082+
PyModule::~PyModule() { mlirModuleDestroy(module); }
11621083

11631084
PyModuleRef PyModule::forModule(MlirModule module) {
11641085
MlirContext context = mlirModuleGetContext(module);
11651086
PyMlirContextRef contextRef = PyMlirContext::forContext(context);
11661087

1167-
nb::gil_scoped_acquire acquire;
1168-
auto &liveModules = contextRef->liveModules;
1169-
auto it = liveModules.find(module.ptr);
1170-
if (it == liveModules.end()) {
1171-
// Create.
1172-
PyModule *unownedModule = new PyModule(std::move(contextRef), module);
1173-
// Note that the default return value policy on cast is automatic_reference,
1174-
// which does not take ownership (delete will not be called).
1175-
// Just be explicit.
1176-
nb::object pyRef = nb::cast(unownedModule, nb::rv_policy::take_ownership);
1177-
unownedModule->handle = pyRef;
1178-
liveModules[module.ptr] =
1179-
std::make_pair(unownedModule->handle, unownedModule);
1180-
return PyModuleRef(unownedModule, std::move(pyRef));
1181-
}
1182-
// Use existing.
1183-
PyModule *existing = it->second.second;
1184-
nb::object pyRef = nb::borrow<nb::object>(it->second.first);
1185-
return PyModuleRef(existing, std::move(pyRef));
1088+
// Create.
1089+
PyModule *unownedModule = new PyModule(std::move(contextRef), module);
1090+
// Note that the default return value policy on cast is `automatic_reference`,
1091+
// which means "does not take ownership, does not call delete/dtor".
1092+
// We use `take_ownership`, which means "Python will call the C++ destructor
1093+
// and delete operator when the Python wrapper is garbage collected", because
1094+
// MlirModule actually wraps OwningOpRef<ModuleOp> (see mlirModuleCreateParse
1095+
// etc).
1096+
nb::object pyRef = nb::cast(unownedModule, nb::rv_policy::take_ownership);
1097+
unownedModule->handle = pyRef;
1098+
return PyModuleRef(unownedModule, std::move(pyRef));
11861099
}
11871100

11881101
nb::object PyModule::createFromCapsule(nb::object capsule) {
@@ -1207,15 +1120,11 @@ PyOperation::~PyOperation() {
12071120
// If the operation has already been invalidated there is nothing to do.
12081121
if (!valid)
12091122
return;
1210-
1211-
// Otherwise, invalidate the operation and remove it from live map when it is
1212-
// attached.
1213-
if (isAttached()) {
1214-
getContext()->clearOperation(*this);
1215-
} else {
1216-
// And destroy it when it is detached, i.e. owned by Python, in which case
1217-
// all nested operations must be invalidated at removed from the live map as
1218-
// well.
1123+
// Otherwise, invalidate the operation when it is attached.
1124+
if (isAttached())
1125+
setInvalid();
1126+
else {
1127+
// And destroy it when it is detached, i.e. owned by Python.
12191128
erase();
12201129
}
12211130
}
@@ -1252,35 +1161,15 @@ PyOperationRef PyOperation::createInstance(PyMlirContextRef contextRef,
12521161
PyOperationRef PyOperation::forOperation(PyMlirContextRef contextRef,
12531162
MlirOperation operation,
12541163
nb::object parentKeepAlive) {
1255-
nb::ft_lock_guard lock(contextRef->liveOperationsMutex);
1256-
auto &liveOperations = contextRef->liveOperations;
1257-
auto it = liveOperations.find(operation.ptr);
1258-
if (it == liveOperations.end()) {
1259-
// Create.
1260-
PyOperationRef result = createInstance(std::move(contextRef), operation,
1261-
std::move(parentKeepAlive));
1262-
liveOperations[operation.ptr] =
1263-
std::make_pair(result.getObject(), result.get());
1264-
return result;
1265-
}
1266-
// Use existing.
1267-
PyOperation *existing = it->second.second;
1268-
nb::object pyRef = nb::borrow<nb::object>(it->second.first);
1269-
return PyOperationRef(existing, std::move(pyRef));
1164+
return createInstance(std::move(contextRef), operation,
1165+
std::move(parentKeepAlive));
12701166
}
12711167

12721168
PyOperationRef PyOperation::createDetached(PyMlirContextRef contextRef,
12731169
MlirOperation operation,
12741170
nb::object parentKeepAlive) {
1275-
nb::ft_lock_guard lock(contextRef->liveOperationsMutex);
1276-
auto &liveOperations = contextRef->liveOperations;
1277-
assert(liveOperations.count(operation.ptr) == 0 &&
1278-
"cannot create detached operation that already exists");
1279-
(void)liveOperations;
12801171
PyOperationRef created = createInstance(std::move(contextRef), operation,
12811172
std::move(parentKeepAlive));
1282-
liveOperations[operation.ptr] =
1283-
std::make_pair(created.getObject(), created.get());
12841173
created->attached = false;
12851174
return created;
12861175
}
@@ -1652,7 +1541,7 @@ nb::object PyOperation::createOpView() {
16521541

16531542
void PyOperation::erase() {
16541543
checkValid();
1655-
getContext()->clearOperationAndInside(*this);
1544+
setInvalid();
16561545
mlirOperationDestroy(operation);
16571546
}
16581547

@@ -3023,14 +2912,6 @@ void mlir::python::populateIRCore(nb::module_ &m) {
30232912
PyMlirContextRef ref = PyMlirContext::forContext(self.get());
30242913
return ref.releaseObject();
30252914
})
3026-
.def("_get_live_operation_count", &PyMlirContext::getLiveOperationCount)
3027-
.def("_get_live_operation_objects",
3028-
&PyMlirContext::getLiveOperationObjects)
3029-
.def("_clear_live_operations", &PyMlirContext::clearLiveOperations)
3030-
.def("_clear_live_operations_inside",
3031-
nb::overload_cast<MlirOperation>(
3032-
&PyMlirContext::clearOperationsInside))
3033-
.def("_get_live_module_count", &PyMlirContext::getLiveModuleCount)
30342915
.def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyMlirContext::getCapsule)
30352916
.def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyMlirContext::createFromCapsule)
30362917
.def("__enter__", &PyMlirContext::contextEnter)
@@ -3348,7 +3229,9 @@ void mlir::python::populateIRCore(nb::module_ &m) {
33483229
//----------------------------------------------------------------------------
33493230
nb::class_<PyModule>(m, "Module", nb::is_weak_referenceable())
33503231
.def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyModule::getCapsule)
3351-
.def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyModule::createFromCapsule)
3232+
.def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyModule::createFromCapsule,
3233+
kModuleCAPICreate)
3234+
.def("_clear_mlir_module", &PyModule::clearMlirModule)
33523235
.def_static(
33533236
"parse",
33543237
[](const std::string &moduleAsm, DefaultingPyMlirContext context) {
@@ -3428,7 +3311,13 @@ void mlir::python::populateIRCore(nb::module_ &m) {
34283311
// Defer to the operation's __str__.
34293312
return self.attr("operation").attr("__str__")();
34303313
},
3431-
kOperationStrDunderDocstring);
3314+
kOperationStrDunderDocstring)
3315+
.def(
3316+
"__eq__",
3317+
[](PyModule &self, PyModule &other) {
3318+
return mlirModuleEqual(self.get(), other.get());
3319+
},
3320+
"other"_a);
34323321

34333322
//----------------------------------------------------------------------------
34343323
// Mapping of Operation.
@@ -3440,7 +3329,8 @@ void mlir::python::populateIRCore(nb::module_ &m) {
34403329
})
34413330
.def("__eq__",
34423331
[](PyOperationBase &self, PyOperationBase &other) {
3443-
return &self.getOperation() == &other.getOperation();
3332+
return mlirOperationEqual(self.getOperation().get(),
3333+
other.getOperation().get());
34443334
})
34453335
.def("__eq__",
34463336
[](PyOperationBase &self, nb::object other) { return false; })
@@ -3655,7 +3545,9 @@ void mlir::python::populateIRCore(nb::module_ &m) {
36553545
[](PyOperationBase &self) {
36563546
return PyOpSuccessors(self.getOperation().getRef());
36573547
},
3658-
"Returns the list of Operation successors.");
3548+
"Returns the list of Operation successors.")
3549+
.def("_set_invalid", &PyOperation::setInvalid,
3550+
"Invalidate the operation.");
36593551

36603552
auto opViewClass =
36613553
nb::class_<PyOpView, PyOperationBase>(m, "OpView")
@@ -3699,7 +3591,11 @@ void mlir::python::populateIRCore(nb::module_ &m) {
36993591
[](PyOperationBase &self) {
37003592
return PyOpSuccessors(self.getOperation().getRef());
37013593
},
3702-
"Returns the list of Operation successors.");
3594+
"Returns the list of Operation successors.")
3595+
.def(
3596+
"_set_invalid",
3597+
[](PyOpView &self) { self.getOperation().setInvalid(); },
3598+
"Invalidate the operation.");
37033599
opViewClass.attr("_ODS_REGIONS") = nb::make_tuple(0, true);
37043600
opViewClass.attr("_ODS_OPERAND_SEGMENTS") = nb::none();
37053601
opViewClass.attr("_ODS_RESULT_SEGMENTS") = nb::none();

0 commit comments

Comments
 (0)