Skip to content
This repository was archived by the owner on Oct 11, 2025. It is now read-only.

Commit ed07158

Browse files
authored
[mlir][py] invalidate nested operations when parent is deleted (#93339)
When an operation is erased in Python, its children may still be in the "live" list inside Python bindings. After this, if some of the newly allocated operations happen to reuse the same pointer address, this will trigger an assertion in the bindings. This assertion would be incorrect because the operations aren't actually live. Make sure we remove the children operations from the "live" list when erasing the parent. This also concentrates responsibility over the removal from the "live" list and invalidation in a single place. Note that this requires the IR to be sufficiently structurally valid so a walk through it can succeed. If this invariant was broken by, e.g, C++ pass called from Python, there isn't much we can do.
1 parent 9b4fb92 commit ed07158

File tree

2 files changed

+29
-13
lines changed

2 files changed

+29
-13
lines changed

IRCore.cpp

Lines changed: 22 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -697,6 +697,17 @@ void PyMlirContext::clearOperationsInside(MlirOperation op) {
697697
clearOperationsInside(opRef->getOperation());
698698
}
699699

700+
void PyMlirContext::clearOperationAndInside(PyOperationBase &op) {
701+
MlirOperationWalkCallback invalidatingCallback = [](MlirOperation op,
702+
void *userData) {
703+
PyMlirContextRef &contextRef = *static_cast<PyMlirContextRef *>(userData);
704+
contextRef->clearOperation(op);
705+
return MlirWalkResult::MlirWalkResultAdvance;
706+
};
707+
mlirOperationWalk(op.getOperation(), invalidatingCallback,
708+
&op.getOperation().getContext(), MlirWalkPreOrder);
709+
}
710+
700711
size_t PyMlirContext::getLiveModuleCount() { return liveModules.size(); }
701712

702713
pybind11::object PyMlirContext::contextEnter() {
@@ -1125,12 +1136,16 @@ PyOperation::~PyOperation() {
11251136
// If the operation has already been invalidated there is nothing to do.
11261137
if (!valid)
11271138
return;
1128-
auto &liveOperations = getContext()->liveOperations;
1129-
assert(liveOperations.count(operation.ptr) == 1 &&
1130-
"destroying operation not in live map");
1131-
liveOperations.erase(operation.ptr);
1132-
if (!isAttached()) {
1133-
mlirOperationDestroy(operation);
1139+
1140+
// Otherwise, invalidate the operation and remove it from live map when it is
1141+
// attached.
1142+
if (isAttached()) {
1143+
getContext()->clearOperation(*this);
1144+
} else {
1145+
// And destroy it when it is detached, i.e. owned by Python, in which case
1146+
// all nested operations must be invalidated at removed from the live map as
1147+
// well.
1148+
erase();
11341149
}
11351150
}
11361151

@@ -1540,14 +1555,8 @@ py::object PyOperation::createOpView() {
15401555

15411556
void PyOperation::erase() {
15421557
checkValid();
1543-
// TODO: Fix memory hazards when erasing a tree of operations for which a deep
1544-
// Python reference to a child operation is live. All children should also
1545-
// have their `valid` bit set to false.
1546-
auto &liveOperations = getContext()->liveOperations;
1547-
if (liveOperations.count(operation.ptr))
1548-
liveOperations.erase(operation.ptr);
1558+
getContext()->clearOperationAndInside(*this);
15491559
mlirOperationDestroy(operation);
1550-
valid = false;
15511560
}
15521561

15531562
//------------------------------------------------------------------------------

IRModule.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -218,13 +218,19 @@ class PyMlirContext {
218218
/// This is useful for when some non-bindings code destroys the operation and
219219
/// the bindings need to made aware. For example, in the case when pass
220220
/// manager is run.
221+
///
222+
/// Note that this does *NOT* clear the nested operations.
221223
void clearOperation(MlirOperation op);
222224

223225
/// Clears all operations nested inside the given op using
224226
/// `clearOperation(MlirOperation)`.
225227
void clearOperationsInside(PyOperationBase &op);
226228
void clearOperationsInside(MlirOperation op);
227229

230+
/// Clears the operaiton _and_ all operations inside using
231+
/// `clearOperation(MlirOperation)`.
232+
void clearOperationAndInside(PyOperationBase &op);
233+
228234
/// Gets the count of live modules associated with this context.
229235
/// Used for testing.
230236
size_t getLiveModuleCount();
@@ -246,6 +252,7 @@ class PyMlirContext {
246252

247253
private:
248254
PyMlirContext(MlirContext context);
255+
249256
// Interns the mapping of live MlirContext::ptr to PyMlirContext instances,
250257
// preserving the relationship that an MlirContext maps to a single
251258
// PyMlirContext wrapper. This could be replaced in the future with an

0 commit comments

Comments
 (0)