@@ -635,6 +635,75 @@ class PyOpOperandIterator {
   MlirOpOperand opOperand;
 };
 
+
+
+#if !defined(Py_GIL_DISABLED)
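+
+// In GIL-enabled builds the GIL serializes all reference-count updates, so a
+// nonzero-refcount check followed by Py_INCREF cannot race with destruction.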
+inline void enableTryIncRef(nb::handle obj) noexcept {}
+inline bool tryIncRef(nb::handle obj) noexcept {
+  if (Py_REFCNT(obj.ptr()) > 0) {
+    Py_INCREF(obj.ptr());
+    return true;
+  }
+  return false;
+}
+
+#elif PY_VERSION_HEX >= 0x030E00A5
+
+// CPython 3.14 provides an unstable API for these.
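+// PyUnstable_TryIncRef increments the refcount only if it is nonzero (i.e.
+// the object is not already being destroyed), and PyUnstable_EnableTryIncRef
+// must be called on the object beforehand, while a strong reference is held.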
+inline void enableTryIncRef(nb::handle obj) noexcept {
+  PyUnstable_EnableTryIncRef(obj.ptr());
+}
+inline bool tryIncRef(nb::handle obj) noexcept {
+  return PyUnstable_TryIncRef(obj.ptr());
+}
+
+#else
+
+// For CPython 3.13 there is no API for this, so we must implement our own.
+// This code originates from https://github.com/wjakob/nanobind/pull/865/files.
+void enableTryIncRef(nb::handle h) noexcept {
+  // Since this is called during object construction, we know that we have
+  // the only reference to the object and can use a non-atomic write.
+  PyObject *obj = h.ptr();
+  assert(obj->ob_ref_shared == 0);
+  obj->ob_ref_shared = _Py_REF_MAYBE_WEAKREF;
+}
+
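+// Under free threading, CPython uses biased reference counting: ob_ref_local
+// counts references held by the owning thread and is only mutated by that
+// thread, while ob_ref_shared holds all other threads' references, shifted
+// left by _Py_REF_SHARED_SHIFT, with the low bits reserved for flags.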
+bool tryIncRef(nb::handle h) noexcept {
+  PyObject *obj = h.ptr();
+  // See https://github.com/python/cpython/blob/d05140f9f77d7dfc753dd1e5ac3a5962aaa03eff/Include/internal/pycore_object.h#L761
+  uint32_t local = _Py_atomic_load_uint32_relaxed(&obj->ob_ref_local);
+  local += 1;
+  if (local == 0) {
+    // The local count wrapped to zero: ob_ref_local held the immortal
+    // sentinel (UINT32_MAX), so no increment is needed.
+    return true;
+  }
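+  // Fast path: only the owning thread writes ob_ref_local, so a relaxed
+  // store (rather than an atomic read-modify-write) is sufficient.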
+  if (_Py_IsOwnedByCurrentThread(obj)) {
+    _Py_atomic_store_uint32_relaxed(&obj->ob_ref_local, local);
+#ifdef Py_REF_DEBUG
+    _Py_INCREF_IncRefTotal();
+#endif
+    return true;
+  }
+  Py_ssize_t shared = _Py_atomic_load_ssize_relaxed(&obj->ob_ref_shared);
+  for (;;) {
+    // If the shared refcount is zero and the object is either merged
+    // or may not have weak references, then we cannot incref it.
+    if (shared == 0 || shared == _Py_REF_MERGED) {
+      return false;
+    }
+
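+    // On failure, _Py_atomic_compare_exchange_ssize reloads 'shared' with
+    // the value observed, and the loop retries.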
+    if (_Py_atomic_compare_exchange_ssize(
+            &obj->ob_ref_shared, &shared,
+            shared + (1 << _Py_REF_SHARED_SHIFT))) {
+#ifdef Py_REF_DEBUG
+      _Py_INCREF_IncRefTotal();
+#endif
+      return true;
+    }
+  }
+}
+#endif
+
 } // namespace
 
 // ------------------------------------------------------------------------------
@@ -706,11 +775,17 @@ size_t PyMlirContext::getLiveOperationCount() {
   return liveOperations.size();
 }
 
-std::vector<PyOperation *> PyMlirContext::getLiveOperationObjects() {
-  std::vector<PyOperation *> liveObjects;
+std::vector<nb::object> PyMlirContext::getLiveOperationObjects() {
+  std::vector<nb::object> liveObjects;
   nb::ft_lock_guard lock(liveOperationsMutex);
-  for (auto &entry : liveOperations)
-    liveObjects.push_back(entry.second.second);
+  for (auto &entry : liveOperations) {
+    // It is not safe to unconditionally increment the reference count here
+    // because an operation that is in the process of being deleted by another
+    // thread may still be present in the map.
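+    // nb::steal adopts the reference acquired by tryIncRef rather than
+    // incrementing again.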
+    if (tryIncRef(entry.second.first)) {
+      liveObjects.push_back(nb::steal(entry.second.first));
+    }
+  }
   return liveObjects;
 }
 
@@ -720,25 +795,26 @@ size_t PyMlirContext::clearLiveOperations() {
   {
     nb::ft_lock_guard lock(liveOperationsMutex);
     std::swap(operations, liveOperations);
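+    // Invalidate the entries while still holding liveOperationsMutex, so
+    // that invalidation and removal from the map appear atomic to readers.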
+    for (auto &op : operations)
+      op.second.second->setInvalidLocked();
   }
-  for (auto &op : operations)
-    op.second.second->setInvalid();
   size_t numInvalidated = operations.size();
   return numInvalidated;
 }
 
-void PyMlirContext::clearOperation(MlirOperation op) {
-  PyOperation *py_op;
-  {
-    nb::ft_lock_guard lock(liveOperationsMutex);
-    auto it = liveOperations.find(op.ptr);
-    if (it == liveOperations.end()) {
-      return;
-    }
-    py_op = it->second.second;
-    liveOperations.erase(it);
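+// The "Locked" suffix indicates that the caller is expected to hold
+// liveOperationsMutex.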
+void PyMlirContext::clearOperationLocked(MlirOperation op) {
+  auto it = liveOperations.find(op.ptr);
+  if (it == liveOperations.end()) {
+    return;
   }
-  py_op->setInvalid();
+  PyOperation *py_op = it->second.second;
+  py_op->setInvalidLocked();
+  liveOperations.erase(it);
+}
+
+void PyMlirContext::clearOperation(MlirOperation op) {
+  nb::ft_lock_guard lock(liveOperationsMutex);
+  clearOperationLocked(op);
 }
 
 void PyMlirContext::clearOperationsInside(PyOperationBase &op) {
@@ -770,7 +846,7 @@ void PyMlirContext::clearOperationAndInside(PyOperationBase &op) {
   MlirOperationWalkCallback invalidatingCallback = [](MlirOperation op,
                                                       void *userData) {
     PyMlirContextRef &contextRef = *static_cast<PyMlirContextRef *>(userData);
-    contextRef->clearOperation(op);
+    contextRef->clearOperationLocked(op);
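+    // clearOperationAndInside is reached with liveOperationsMutex already
+    // held (e.g. via eraseLocked), hence the Locked variant.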
     return MlirWalkResult::MlirWalkResultAdvance;
   };
   mlirOperationWalk(op.getOperation(), invalidatingCallback,
@@ -1203,19 +1279,23 @@ PyOperation::PyOperation(PyMlirContextRef contextRef, MlirOperation operation)
     : BaseContextObject(std::move(contextRef)), operation(operation) {}
 
 PyOperation::~PyOperation() {
+  PyMlirContextRef context = getContext();
+  nb::ft_lock_guard lock(context->liveOperationsMutex);
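+  // Acquire the mutex that guards the liveOperations map so that destruction
+  // cannot race with forOperation() looking up this entry.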
   // If the operation has already been invalidated there is nothing to do.
   if (!valid)
     return;
 
   // Otherwise, invalidate the operation and remove it from the live map when
   // it is attached.
   if (isAttached()) {
-    getContext()->clearOperation(*this);
+    // Since the operation was valid, we know that it is this object that is
+    // present in the map, not some other object.
+    context->liveOperations.erase(operation.ptr);
   } else {
     // And destroy it when it is detached, i.e. owned by Python, in which case
     // all nested operations must be invalidated and removed from the live map
     // as well.
-    erase();
+    eraseLocked();
   }
 }
 
@@ -1241,6 +1321,7 @@ PyOperationRef PyOperation::createInstance(PyMlirContextRef contextRef,
   // Create.
   PyOperationRef unownedOperation =
       makeObjectRef<PyOperation>(std::move(contextRef), operation);
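+  // Mark the new Python object so that tryIncRef can later be used on it
+  // safely (see forOperation and getLiveOperationObjects).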
+  enableTryIncRef(unownedOperation.getObject());
   unownedOperation->handle = unownedOperation.getObject();
   if (parentKeepAlive) {
     unownedOperation->parentKeepAlive = std::move(parentKeepAlive);
@@ -1254,18 +1335,26 @@ PyOperationRef PyOperation::forOperation(PyMlirContextRef contextRef,
   nb::ft_lock_guard lock(contextRef->liveOperationsMutex);
   auto &liveOperations = contextRef->liveOperations;
   auto it = liveOperations.find(operation.ptr);
-  if (it == liveOperations.end()) {
-    // Create.
-    PyOperationRef result = createInstance(std::move(contextRef), operation,
-                                           std::move(parentKeepAlive));
-    liveOperations[operation.ptr] =
-        std::make_pair(result.getObject(), result.get());
-    return result;
+  if (it != liveOperations.end()) {
+    PyOperation *existing = it->second.second;
+    nb::handle pyRef = it->second.first;
+
+    // Try to increment the reference count of the existing entry. This can
+    // fail if the object is in the process of being destroyed by another
+    // thread.
+    if (tryIncRef(pyRef)) {
+      return PyOperationRef(existing, nb::steal<nb::object>(pyRef));
+    }
+
+    // Mark the existing entry as invalid, since we are about to replace it.
+    existing->valid = false;
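+    // Writing 'valid' directly is safe here: liveOperationsMutex is held,
+    // and the object's destructor, which takes the same lock, will then skip
+    // erasing the map entry we are about to overwrite.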
   }
-  // Use existing.
-  PyOperation *existing = it->second.second;
-  nb::object pyRef = nb::borrow<nb::object>(it->second.first);
-  return PyOperationRef(existing, std::move(pyRef));
+
+  // Create a new wrapper object.
+  PyOperationRef result = createInstance(std::move(contextRef), operation,
+                                         std::move(parentKeepAlive));
+  liveOperations[operation.ptr] =
+      std::make_pair(result.getObject(), result.get());
+  return result;
 }
 
 PyOperationRef PyOperation::createDetached(PyMlirContextRef contextRef,
@@ -1297,6 +1386,7 @@ PyOperationRef PyOperation::parse(PyMlirContextRef contextRef,
 }
 
 void PyOperation::checkValid() const {
+  nb::ft_lock_guard lock(getContext()->liveOperationsMutex);
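+  // 'valid' is written under this mutex, so locking avoids a data race with
+  // concurrent invalidation.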
   if (!valid) {
     throw std::runtime_error("the operation has been invalidated");
   }
@@ -1638,12 +1728,17 @@ nb::object PyOperation::createOpView() {
   return nb::cast(PyOpView(getRef().getObject()));
 }
 
-void PyOperation::erase() {
+void PyOperation::eraseLocked() {
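+  // Expects the caller to hold liveOperationsMutex; the public erase() below
+  // acquires it first.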
   checkValid();
   getContext()->clearOperationAndInside(*this);
   mlirOperationDestroy(operation);
 }
 
+void PyOperation::erase() {
+  nb::ft_lock_guard lock(getContext()->liveOperationsMutex);
+  eraseLocked();
+}
+
 namespace {
 /// CRTP base class for Python MLIR values that subclass Value and should be
 /// castable from it. The value hierarchy is one level deep and is not supposed
@@ -2324,7 +2419,7 @@ void PySymbolTable::erase(PyOperationBase &symbol) {
   // The operation is also erased, so we must invalidate it. There may be Python
   // references to this operation so we don't want to delete it from the list of
   // live operations here.
-  symbol.getOperation().valid = false;
+  symbol.getOperation().setInvalid();
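+  // Unlike the raw field write it replaces, setInvalid() acquires the
+  // context's liveOperationsMutex before writing.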
 }
 
 void PySymbolTable::dunderDel(const std::string &name) {