@@ -713,20 +713,21 @@ size_t PyMlirContext::clearLiveOperations() {
713713 return numInvalidated;
714714}
715715
716- void PyMlirContext::clearOperation (MlirOperation op) {
717- PyOperation *py_op;
718- {
719- nb::ft_lock_guard lock (liveOperationsMutex);
720- auto it = liveOperations.find (op.ptr );
721- if (it == liveOperations.end ()) {
722- return ;
723- }
724- py_op = it->second .second ;
725- liveOperations.erase (it);
716+ void PyMlirContext::_clearOperationLocked (MlirOperation op) {
717+ auto it = liveOperations.find (op.ptr );
718+ if (it == liveOperations.end ()) {
719+ return ;
726720 }
721+ PyOperation *py_op = it->second .second ;
722+ liveOperations.erase (it);
727723 py_op->setInvalid ();
728724}
729725
726+ void PyMlirContext::clearOperation (MlirOperation op) {
727+ nb::ft_lock_guard lock (liveOperationsMutex);
728+ _clearOperationLocked (op);
729+ }
730+
730731void PyMlirContext::clearOperationsInside (PyOperationBase &op) {
731732 typedef struct {
732733 PyOperation &rootOp;
@@ -752,15 +753,30 @@ void PyMlirContext::clearOperationsInside(MlirOperation op) {
752753 clearOperationsInside (opRef->getOperation ());
753754}
754755
756+ void _clearOperationAndInsideHelper (
757+ PyOperation &op, MlirOperationWalkCallback invalidatingCallback
758+ ) {
759+ mlirOperationWalk (op, invalidatingCallback, &op.getContext (), MlirWalkPreOrder);
760+ }
761+
762+ void PyMlirContext::_clearOperationAndInsideLocked (PyOperationBase &op) {
763+ MlirOperationWalkCallback invalidatingCallbackLocked = [](MlirOperation op,
764+ void *userData) {
765+ PyMlirContextRef &contextRef = *static_cast <PyMlirContextRef *>(userData);
766+ contextRef->_clearOperationLocked (op);
767+ return MlirWalkResult::MlirWalkResultAdvance;
768+ };
769+ _clearOperationAndInsideHelper (op.getOperation (), invalidatingCallbackLocked);
770+ }
771+
755772void PyMlirContext::clearOperationAndInside (PyOperationBase &op) {
756773 MlirOperationWalkCallback invalidatingCallback = [](MlirOperation op,
757774 void *userData) {
758775 PyMlirContextRef &contextRef = *static_cast <PyMlirContextRef *>(userData);
759776 contextRef->clearOperation (op);
760777 return MlirWalkResult::MlirWalkResultAdvance;
761778 };
762- mlirOperationWalk (op.getOperation (), invalidatingCallback,
763- &op.getOperation ().getContext (), MlirWalkPreOrder);
779+ _clearOperationAndInsideHelper (op.getOperation (), invalidatingCallback);
764780}
765781
766782size_t PyMlirContext::getLiveModuleCount () { return liveModules.size (); }
@@ -1189,19 +1205,25 @@ PyOperation::PyOperation(PyMlirContextRef contextRef, MlirOperation operation)
11891205 : BaseContextObject(std::move(contextRef)), operation(operation) {}
11901206
11911207PyOperation::~PyOperation () {
1192- // If the operation has already been invalidated there is nothing to do.
1193- if (!valid)
1194- return ;
1208+ // This lock serializes access between ~PyOperation and
1209+ // PyOperation::forOperation when an existing PyOperation must be invalidated.
1210+ nb::ft_lock_guard lock (getContext ()->liveOperationsMutex );
1211+ {
1212+ nb::ft_lock_guard lock2 (opMutex);
1213+ if (!valid)
1214+ return ;
1215+ }
11951216
11961217 // Otherwise, invalidate the operation and remove it from live map when it is
11971218 // attached.
11981219 if (isAttached ()) {
1199- getContext ()->clearOperation (* this );
1220+ getContext ()->_clearOperationLocked (operation );
12001221 } else {
12011222 // And destroy it when it is detached, i.e. owned by Python, in which case
12021223 // all nested operations must be invalidated and removed from the live map as
12031224 // well.
1204- erase ();
1225+ getContext ()->_clearOperationAndInsideLocked (*this );
1226+ mlirOperationDestroy (operation);
12051227 }
12061228}
12071229
@@ -1234,6 +1256,41 @@ PyOperationRef PyOperation::createInstance(PyMlirContextRef contextRef,
12341256 return unownedOperation;
12351257}
12361258
1259+
1260+ bool _nb_try_inc_ref (PyObject *obj) {
1261+ // See https://github.com/python/cpython/blob/d05140f9f77d7dfc753dd1e5ac3a5962aaa03eff/Include/internal/pycore_object.h#L761
1262+ uint32_t local = _Py_atomic_load_uint32_relaxed (&obj->ob_ref_local );
1263+ local += 1 ;
1264+ if (local == 0 ) {
1265+ // immortal
1266+ return true ;
1267+ }
1268+ if (_Py_IsOwnedByCurrentThread (obj)) {
1269+ _Py_atomic_store_uint32_relaxed (&obj->ob_ref_local , local);
1270+ #ifdef Py_REF_DEBUG
1271+ _Py_INCREF_IncRefTotal ();
1272+ #endif
1273+ return true ;
1274+ }
1275+ Py_ssize_t shared = _Py_atomic_load_ssize_relaxed (&obj->ob_ref_shared );
1276+ for (;;) {
1277+ // If the shared refcount is zero and the object is either merged
1278+ // or may not have weak references, then we cannot incref it.
1279+ if (shared == 0 || shared == _Py_REF_MERGED) {
1280+ return false ;
1281+ }
1282+
1283+ if (_Py_atomic_compare_exchange_ssize (
1284+ &obj->ob_ref_shared , &shared, shared + (1 << _Py_REF_SHARED_SHIFT))) {
1285+ #ifdef Py_REF_DEBUG
1286+ _Py_INCREF_IncRefTotal ();
1287+ #endif
1288+ return true ;
1289+ }
1290+ }
1291+ }
1292+
1293+
12371294PyOperationRef PyOperation::forOperation (PyMlirContextRef contextRef,
12381295 MlirOperation operation,
12391296 nb::object parentKeepAlive) {
@@ -1250,8 +1307,31 @@ PyOperationRef PyOperation::forOperation(PyMlirContextRef contextRef,
12501307 }
12511308 // Use existing.
12521309 PyOperation *existing = it->second .second ;
1253- nb::object pyRef = nb::borrow<nb::object>(it->second .first );
1254- return PyOperationRef (existing, std::move (pyRef));
1310+ nb::object pyRef = nb::steal (it->second .first );
1311+
1312+ // Check whether pyRef is in the process of being destroyed, in which case
1313+ // incrementing its reference count cannot keep it from deletion.
1314+ // If the speculative increment fails, the Python object is already being
1315+ // torn down and ~PyOperation will run for it.
1316+ // Thus, we must create a new PyOperationRef instead of reusing it.
1317+ if (_nb_try_inc_ref (pyRef.ptr ())) {
1318+ return PyOperationRef (existing, std::move (pyRef));
1319+ }
1320+
1321+ // Lock ordering: liveOperationsMutex must be acquired before opMutex.
1322+ // existing->opMutex serializes ~PyOperation against
1323+ // the invalidation below.
1324+ nb::ft_lock_guard lock2 (existing->opMutex );
1325+
1326+ // Invalidate existing
1327+ existing->valid = false ;
1328+
1329+ // Create.
1330+ PyOperationRef result = createInstance (std::move (contextRef), operation,
1331+ std::move (parentKeepAlive));
1332+ liveOperations[operation.ptr ] =
1333+ std::make_pair (result.getObject (), result.get ());
1334+ return result;
12551335}
12561336
12571337PyOperationRef PyOperation::createDetached (PyMlirContextRef contextRef,
@@ -1282,7 +1362,8 @@ PyOperationRef PyOperation::parse(PyMlirContextRef contextRef,
12821362 return PyOperation::createDetached (std::move (contextRef), op);
12831363}
12841364
1285- void PyOperation::checkValid () const {
1365+ void PyOperation::checkValid () {
1366+ nb::ft_lock_guard lock (opMutex);
12861367 if (!valid) {
12871368 throw std::runtime_error (" the operation has been invalidated" );
12881369 }
@@ -2305,6 +2386,7 @@ void PySymbolTable::erase(PyOperationBase &symbol) {
23052386 // The operation is also erased, so we must invalidate it. There may be Python
23062387 // references to this operation so we don't want to delete it from the list of
23072388 // live operations here.
2389+ nb::ft_lock_guard lock (symbol.getOperation ().opMutex );
23082390 symbol.getOperation ().valid = false ;
23092391}
23102392
0 commit comments