Skip to content

Commit 7bd311d

Browse files
committed
Fix race in MLIR Python bindings.
1 parent 2d287f5 commit 7bd311d

File tree

2 files changed

+156
-40
lines changed

2 files changed

+156
-40
lines changed

mlir/lib/Bindings/Python/IRCore.cpp

Lines changed: 128 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -635,6 +635,75 @@ class PyOpOperandIterator {
635635
MlirOpOperand opOperand;
636636
};
637637

638+
639+
640+
#if !defined(Py_GIL_DISABLED)
641+
/// In builds where Py_GIL_DISABLED is not defined there is nothing to set up
/// before tryIncRef() can be used, so this is intentionally empty.
inline void enableTryIncRef(nb::handle obj) noexcept {}

/// Takes a new strong reference to `obj` if its reference count is still
/// positive. Returns true when the reference was taken, false otherwise.
inline bool tryIncRef(nb::handle obj) noexcept {
  PyObject *raw = obj.ptr();
  if (Py_REFCNT(raw) <= 0)
    return false;
  Py_INCREF(raw);
  return true;
}
649+
650+
#elif PY_VERSION_HEX >= 0x030E00A5
651+
652+
// CPython 3.14 provides an unstable API for these.
653+
/// Forwards to PyUnstable_EnableTryIncRef() for the wrapped object.
inline void enableTryIncRef(nb::handle obj) noexcept {
  PyObject *raw = obj.ptr();
  PyUnstable_EnableTryIncRef(raw);
}

/// Forwards to PyUnstable_TryIncRef() for the wrapped object and returns its
/// result converted to bool.
inline bool tryIncRef(nb::handle obj) noexcept {
  PyObject *raw = obj.ptr();
  return PyUnstable_TryIncRef(raw);
}
659+
660+
#else
661+
662+
// For CPython 3.13 there is no API for this, and so we must implement our own.
663+
// This code originates from https://github.com/wjakob/nanobind/pull/865/files.
664+
/// Prepares the object referenced by `h` for later tryIncRef() calls by
/// storing _Py_REF_MAYBE_WEAKREF into its shared reference count field.
///
/// Must be called during object construction, while the caller holds the only
/// reference to the object.
void enableTryIncRef(nb::handle h) noexcept {
  // Since this is called during object construction, we know that we have
  // the only reference to the object and can use a non-atomic write.
  PyObject *obj = h.ptr();
  // Access the refcount field through the raw PyObject pointer rather than
  // through the nanobind handle wrapper.
  assert(obj->ob_ref_shared == 0);
  obj->ob_ref_shared = _Py_REF_MAYBE_WEAKREF;
}
671+
672+
// Attempts to take a new strong reference to the object referenced by `h`.
// Returns true if the reference count was incremented (or the object is
// treated as immortal), and false if the shared reference count is zero or
// marked as merged, in which case no reference is taken.
bool tryIncRef(nb::handle h) noexcept {
673+
PyObject *obj = h.ptr();
674+
// See https://github.com/python/cpython/blob/d05140f9f77d7dfc753dd1e5ac3a5962aaa03eff/Include/internal/pycore_object.h#L761
675+
uint32_t local = _Py_atomic_load_uint32_relaxed(&obj->ob_ref_local);
676+
// Adding one wraps to zero only when the local count held its maximum value.
local += 1;
677+
if (local == 0) {
678+
// immortal
679+
return true;
680+
}
681+
// When the calling thread owns the object, the incremented local count is
// written back with a relaxed store and no compare-exchange is needed.
if (_Py_IsOwnedByCurrentThread(obj)) {
682+
_Py_atomic_store_uint32_relaxed(&obj->ob_ref_local, local);
683+
#ifdef Py_REF_DEBUG
684+
_Py_INCREF_IncRefTotal();
685+
#endif
686+
return true;
687+
}
688+
// Otherwise the shared count is bumped with a compare-exchange loop.
// NOTE(review): presumably a failed compare-exchange reloads `shared` with the
// current field value before the next iteration -- confirm against CPython.
Py_ssize_t shared = _Py_atomic_load_ssize_relaxed(&obj->ob_ref_shared);
689+
for (;;) {
690+
// If the shared refcount is zero and the object is either merged
691+
// or may not have weak references, then we cannot incref it.
692+
if (shared == 0 || shared == _Py_REF_MERGED) {
693+
return false;
694+
}
695+
696+
if (_Py_atomic_compare_exchange_ssize(
697+
&obj->ob_ref_shared, &shared, shared + (1 << _Py_REF_SHARED_SHIFT))) {
698+
#ifdef Py_REF_DEBUG
699+
_Py_INCREF_IncRefTotal();
700+
#endif
701+
return true;
702+
}
703+
}
704+
}
705+
#endif
706+
638707
} // namespace
639708

640709
//------------------------------------------------------------------------------
@@ -706,11 +775,17 @@ size_t PyMlirContext::getLiveOperationCount() {
706775
return liveOperations.size();
707776
}
708777

709-
std::vector<PyOperation *> PyMlirContext::getLiveOperationObjects() {
710-
std::vector<PyOperation *> liveObjects;
778+
std::vector<nb::object> PyMlirContext::getLiveOperationObjects() {
779+
std::vector<nb::object> liveObjects;
711780
nb::ft_lock_guard lock(liveOperationsMutex);
712-
for (auto &entry : liveOperations)
713-
liveObjects.push_back(entry.second.second);
781+
for (auto &entry : liveOperations) {
782+
// It is not safe to unconditionally increment the reference count here
783+
// because an operation that is in the process of being deleted by another
784+
// thread may still be present in the map.
785+
if (tryIncRef(entry.second.first)) {
786+
liveObjects.push_back(nb::steal(entry.second.first));
787+
}
788+
}
714789
return liveObjects;
715790
}
716791

@@ -720,25 +795,26 @@ size_t PyMlirContext::clearLiveOperations() {
720795
{
721796
nb::ft_lock_guard lock(liveOperationsMutex);
722797
std::swap(operations, liveOperations);
798+
for (auto &op : operations)
799+
op.second.second->setInvalidLocked();
723800
}
724-
for (auto &op : operations)
725-
op.second.second->setInvalid();
726801
size_t numInvalidated = operations.size();
727802
return numInvalidated;
728803
}
729804

730-
void PyMlirContext::clearOperation(MlirOperation op) {
731-
PyOperation *py_op;
732-
{
733-
nb::ft_lock_guard lock(liveOperationsMutex);
734-
auto it = liveOperations.find(op.ptr);
735-
if (it == liveOperations.end()) {
736-
return;
737-
}
738-
py_op = it->second.second;
739-
liveOperations.erase(it);
805+
void PyMlirContext::clearOperationLocked(MlirOperation op) {
806+
auto it = liveOperations.find(op.ptr);
807+
if (it == liveOperations.end()) {
808+
return;
740809
}
741-
py_op->setInvalid();
810+
PyOperation *py_op = it->second.second;
811+
py_op->setInvalidLocked();
812+
liveOperations.erase(it);
813+
}
814+
815+
/// Acquires liveOperationsMutex and then delegates to clearOperationLocked()
/// for `op`.
void PyMlirContext::clearOperation(MlirOperation op) {
  nb::ft_lock_guard guard(liveOperationsMutex);
  clearOperationLocked(op);
}
743819

744820
void PyMlirContext::clearOperationsInside(PyOperationBase &op) {
@@ -770,7 +846,7 @@ void PyMlirContext::clearOperationAndInside(PyOperationBase &op) {
770846
MlirOperationWalkCallback invalidatingCallback = [](MlirOperation op,
771847
void *userData) {
772848
PyMlirContextRef &contextRef = *static_cast<PyMlirContextRef *>(userData);
773-
contextRef->clearOperation(op);
849+
contextRef->clearOperationLocked(op);
774850
return MlirWalkResult::MlirWalkResultAdvance;
775851
};
776852
mlirOperationWalk(op.getOperation(), invalidatingCallback,
@@ -1203,19 +1279,23 @@ PyOperation::PyOperation(PyMlirContextRef contextRef, MlirOperation operation)
12031279
: BaseContextObject(std::move(contextRef)), operation(operation) {}
12041280

12051281
PyOperation::~PyOperation() {
1282+
PyMlirContextRef context = getContext();
1283+
nb::ft_lock_guard lock(context->liveOperationsMutex);
12061284
// If the operation has already been invalidated there is nothing to do.
12071285
if (!valid)
12081286
return;
12091287

12101288
// Otherwise, invalidate the operation and remove it from live map when it is
12111289
// attached.
12121290
if (isAttached()) {
1213-
getContext()->clearOperation(*this);
1291+
// Since the operation was valid, we know that it is this object present
1292+
// in the map, not some other object.
1293+
context->liveOperations.erase(operation.ptr);
12141294
} else {
12151295
// And destroy it when it is detached, i.e. owned by Python, in which case
12161296
// all nested operations must be invalidated and removed from the live map as
12171297
// well.
1218-
erase();
1298+
eraseLocked();
12191299
}
12201300
}
12211301

@@ -1241,6 +1321,7 @@ PyOperationRef PyOperation::createInstance(PyMlirContextRef contextRef,
12411321
// Create.
12421322
PyOperationRef unownedOperation =
12431323
makeObjectRef<PyOperation>(std::move(contextRef), operation);
1324+
enableTryIncRef(unownedOperation.getObject());
12441325
unownedOperation->handle = unownedOperation.getObject();
12451326
if (parentKeepAlive) {
12461327
unownedOperation->parentKeepAlive = std::move(parentKeepAlive);
@@ -1254,18 +1335,26 @@ PyOperationRef PyOperation::forOperation(PyMlirContextRef contextRef,
12541335
nb::ft_lock_guard lock(contextRef->liveOperationsMutex);
12551336
auto &liveOperations = contextRef->liveOperations;
12561337
auto it = liveOperations.find(operation.ptr);
1257-
if (it == liveOperations.end()) {
1258-
// Create.
1259-
PyOperationRef result = createInstance(std::move(contextRef), operation,
1260-
std::move(parentKeepAlive));
1261-
liveOperations[operation.ptr] =
1262-
std::make_pair(result.getObject(), result.get());
1263-
return result;
1338+
if (it != liveOperations.end()) {
1339+
PyOperation *existing = it->second.second;
1340+
nb::handle pyRef = it->second.first;
1341+
1342+
// Try to increment the reference count of the existing entry. This can fail
1343+
// if the object is in the process of being destroyed by another thread.
1344+
if (tryIncRef(pyRef)) {
1345+
return PyOperationRef(existing, nb::steal<nb::object>(pyRef));
1346+
}
1347+
1348+
// Mark the existing entry as invalid, since we are about to replace it.
1349+
existing->valid = false;
12641350
}
1265-
// Use existing.
1266-
PyOperation *existing = it->second.second;
1267-
nb::object pyRef = nb::borrow<nb::object>(it->second.first);
1268-
return PyOperationRef(existing, std::move(pyRef));
1351+
1352+
// Create a new wrapper object.
1353+
PyOperationRef result = createInstance(std::move(contextRef), operation,
1354+
std::move(parentKeepAlive));
1355+
liveOperations[operation.ptr] =
1356+
std::make_pair(result.getObject(), result.get());
1357+
return result;
12691358
}
12701359

12711360
PyOperationRef PyOperation::createDetached(PyMlirContextRef contextRef,
@@ -1297,6 +1386,7 @@ PyOperationRef PyOperation::parse(PyMlirContextRef contextRef,
12971386
}
12981387

12991388
void PyOperation::checkValid() const {
1389+
nb::ft_lock_guard lock(getContext()->liveOperationsMutex);
13001390
if (!valid) {
13011391
throw std::runtime_error("the operation has been invalidated");
13021392
}
@@ -1638,12 +1728,17 @@ nb::object PyOperation::createOpView() {
16381728
return nb::cast(PyOpView(getRef().getObject()));
16391729
}
16401730

1641-
void PyOperation::erase() {
1731+
void PyOperation::eraseLocked() {
16421732
checkValid();
16431733
getContext()->clearOperationAndInside(*this);
16441734
mlirOperationDestroy(operation);
16451735
}
16461736

1737+
void PyOperation::erase() {
1738+
nb::ft_lock_guard lock(getContext()->liveOperationsMutex);
1739+
eraseLocked();
1740+
}
1741+
16471742
namespace {
16481743
/// CRTP base class for Python MLIR values that subclass Value and should be
16491744
/// castable from it. The value hierarchy is one level deep and is not supposed
@@ -2324,7 +2419,7 @@ void PySymbolTable::erase(PyOperationBase &symbol) {
23242419
// The operation is also erased, so we must invalidate it. There may be Python
23252420
// references to this operation so we don't want to delete it from the list of
23262421
// live operations here.
2327-
symbol.getOperation().valid = false;
2422+
symbol.getOperation().setInvalid();
23282423
}
23292424

23302425
void PySymbolTable::dunderDel(const std::string &name) {

mlir/lib/Bindings/Python/IRModule.h

Lines changed: 28 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,7 @@ class PyObjectRef {
8383
}
8484

8585
T *get() { return referrent; }
86-
T *operator->() {
86+
T *operator->() const {
8787
assert(referrent && object);
8888
return referrent;
8989
}
@@ -229,7 +229,7 @@ class PyMlirContext {
229229
static size_t getLiveCount();
230230

231231
/// Get a list of Python objects which are still in the live context map.
232-
std::vector<PyOperation *> getLiveOperationObjects();
232+
std::vector<nanobind::object> getLiveOperationObjects();
233233

234234
/// Gets the count of live operations associated with this context.
235235
/// Used for testing.
@@ -254,8 +254,9 @@ class PyMlirContext {
254254
void clearOperationsInside(PyOperationBase &op);
255255
void clearOperationsInside(MlirOperation op);
256256

257-
/// Clears the operaiton _and_ all operations inside using
258-
/// `clearOperation(MlirOperation)`.
257+
/// Clears the operation _and_ all operations inside using
258+
/// `clearOperationLocked(MlirOperation)`. Requires that liveOperations mutex is
259+
/// held.
259260
void clearOperationAndInside(PyOperationBase &op);
260261

261262
/// Gets the count of live modules associated with this context.
@@ -278,6 +279,9 @@ class PyMlirContext {
278279
struct ErrorCapture;
279280

280281
private:
282+
// Similar to clearOperation, but requires the liveOperations mutex to be held
283+
void clearOperationLocked(MlirOperation op);
284+
281285
// Interns the mapping of live MlirContext::ptr to PyMlirContext instances,
282286
// preserving the relationship that an MlirContext maps to a single
283287
// PyMlirContext wrapper. This could be replaced in the future with an
@@ -302,6 +306,9 @@ class PyMlirContext {
302306
// attempt to access it will raise an error.
303307
using LiveOperationMap =
304308
llvm::DenseMap<void *, std::pair<nanobind::handle, PyOperation *>>;
309+
310+
// liveOperationsMutex guards both liveOperations and the valid field of
311+
// PyOperation objects in free-threading mode.
305312
nanobind::ft_mutex liveOperationsMutex;
306313

307314
// Guarded by liveOperationsMutex in free-threading mode.
@@ -336,6 +343,7 @@ class BaseContextObject {
336343
}
337344

338345
/// Accesses the context reference.
346+
const PyMlirContextRef &getContext() const { return contextRef; }
339347
PyMlirContextRef &getContext() { return contextRef; }
340348

341349
private:
@@ -725,19 +733,29 @@ class PyOperation : public PyOperationBase, public BaseContextObject {
725733
/// parent context's live operations map, and sets the valid bit false.
726734
void erase();
727735

728-
/// Invalidate the operation.
729-
void setInvalid() { valid = false; }
730-
731736
/// Clones this operation.
732737
nanobind::object clone(const nanobind::object &ip);
733738

739+
/// Invalidate the operation.
740+
void setInvalid() {
741+
nanobind::ft_lock_guard lock(getContext()->liveOperationsMutex);
742+
setInvalidLocked();
743+
}
744+
/// Like setInvalid(), but requires the liveOperations mutex to be held.
745+
void setInvalidLocked() {
746+
valid = false;
747+
}
748+
734749
PyOperation(PyMlirContextRef contextRef, MlirOperation operation);
735750

736751
private:
737752
static PyOperationRef createInstance(PyMlirContextRef contextRef,
738753
MlirOperation operation,
739754
nanobind::object parentKeepAlive);
740755

756+
// Like erase(), but requires the caller to hold the liveOperationsMutex.
757+
void eraseLocked();
758+
741759
MlirOperation operation;
742760
nanobind::handle handle;
743761
// Keeps the parent alive, regardless of whether it is an Operation or
@@ -748,6 +766,9 @@ class PyOperation : public PyOperationBase, public BaseContextObject {
748766
// ir_operation.py regarding testing corresponding lifetime guarantees.
749767
nanobind::object parentKeepAlive;
750768
bool attached = true;
769+
770+
// Guarded by 'context->liveOperationsMutex'. Valid objects must be present
771+
// in context->liveOperations.
751772
bool valid = true;
752773

753774
friend class PyOperationBase;

0 commit comments

Comments
 (0)