Skip to content

Commit b31cf08

Browse files
committed
Use liveOperationsMutex in ~PyOperation and
lock first liveOperationsMutex and then opMutex
1 parent aace6a2 commit b31cf08

File tree

3 files changed

+153
-24
lines changed

3 files changed

+153
-24
lines changed

mlir/lib/Bindings/Python/IRCore.cpp

Lines changed: 102 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -713,20 +713,21 @@ size_t PyMlirContext::clearLiveOperations() {
713713
return numInvalidated;
714714
}
715715

716-
void PyMlirContext::clearOperation(MlirOperation op) {
717-
PyOperation *py_op;
718-
{
719-
nb::ft_lock_guard lock(liveOperationsMutex);
720-
auto it = liveOperations.find(op.ptr);
721-
if (it == liveOperations.end()) {
722-
return;
723-
}
724-
py_op = it->second.second;
725-
liveOperations.erase(it);
716+
void PyMlirContext::_clearOperationLocked(MlirOperation op) {
717+
auto it = liveOperations.find(op.ptr);
718+
if (it == liveOperations.end()) {
719+
return;
726720
}
721+
PyOperation *py_op = it->second.second;
722+
liveOperations.erase(it);
727723
py_op->setInvalid();
728724
}
729725

726+
void PyMlirContext::clearOperation(MlirOperation op) {
727+
nb::ft_lock_guard lock(liveOperationsMutex);
728+
_clearOperationLocked(op);
729+
}
730+
730731
void PyMlirContext::clearOperationsInside(PyOperationBase &op) {
731732
typedef struct {
732733
PyOperation &rootOp;
@@ -752,15 +753,30 @@ void PyMlirContext::clearOperationsInside(MlirOperation op) {
752753
clearOperationsInside(opRef->getOperation());
753754
}
754755

756+
void _clearOperationAndInsideHelper(
757+
PyOperation &op, MlirOperationWalkCallback invalidatingCallback
758+
) {
759+
mlirOperationWalk(op, invalidatingCallback, &op.getContext(), MlirWalkPreOrder);
760+
}
761+
762+
void PyMlirContext::_clearOperationAndInsideLocked(PyOperationBase &op) {
763+
MlirOperationWalkCallback invalidatingCallbackLocked = [](MlirOperation op,
764+
void *userData) {
765+
PyMlirContextRef &contextRef = *static_cast<PyMlirContextRef *>(userData);
766+
contextRef->_clearOperationLocked(op);
767+
return MlirWalkResult::MlirWalkResultAdvance;
768+
};
769+
_clearOperationAndInsideHelper(op.getOperation(), invalidatingCallbackLocked);
770+
}
771+
755772
void PyMlirContext::clearOperationAndInside(PyOperationBase &op) {
756773
MlirOperationWalkCallback invalidatingCallback = [](MlirOperation op,
757774
void *userData) {
758775
PyMlirContextRef &contextRef = *static_cast<PyMlirContextRef *>(userData);
759776
contextRef->clearOperation(op);
760777
return MlirWalkResult::MlirWalkResultAdvance;
761778
};
762-
mlirOperationWalk(op.getOperation(), invalidatingCallback,
763-
&op.getOperation().getContext(), MlirWalkPreOrder);
779+
_clearOperationAndInsideHelper(op.getOperation(), invalidatingCallback);
764780
}
765781

766782
size_t PyMlirContext::getLiveModuleCount() { return liveModules.size(); }
@@ -1189,19 +1205,25 @@ PyOperation::PyOperation(PyMlirContextRef contextRef, MlirOperation operation)
11891205
: BaseContextObject(std::move(contextRef)), operation(operation) {}
11901206

11911207
PyOperation::~PyOperation() {
1192-
// If the operation has already been invalidated there is nothing to do.
1193-
if (!valid)
1194-
return;
1208+
// This lock helps to serialize the access to ~PyOperation and PyOperation::forOperation
1209+
// when we should invalidate existing PyOperation
1210+
nb::ft_lock_guard lock(getContext()->liveOperationsMutex);
1211+
{
1212+
nb::ft_lock_guard lock2(opMutex);
1213+
if (!valid)
1214+
return;
1215+
}
11951216

11961217
// Otherwise, invalidate the operation and remove it from live map when it is
11971218
// attached.
11981219
if (isAttached()) {
1199-
getContext()->clearOperation(*this);
1220+
getContext()->_clearOperationLocked(operation);
12001221
} else {
12011222
// And destroy it when it is detached, i.e. owned by Python, in which case
12021223
// all nested operations must be invalidated at removed from the live map as
12031224
// well.
1204-
erase();
1225+
getContext()->_clearOperationAndInsideLocked(*this);
1226+
mlirOperationDestroy(operation);
12051227
}
12061228
}
12071229

@@ -1234,6 +1256,41 @@ PyOperationRef PyOperation::createInstance(PyMlirContextRef contextRef,
12341256
return unownedOperation;
12351257
}
12361258

1259+
1260+
bool _nb_try_inc_ref(PyObject *obj) {
1261+
// See https://github.com/python/cpython/blob/d05140f9f77d7dfc753dd1e5ac3a5962aaa03eff/Include/internal/pycore_object.h#L761
1262+
uint32_t local = _Py_atomic_load_uint32_relaxed(&obj->ob_ref_local);
1263+
local += 1;
1264+
if (local == 0) {
1265+
// immortal
1266+
return true;
1267+
}
1268+
if (_Py_IsOwnedByCurrentThread(obj)) {
1269+
_Py_atomic_store_uint32_relaxed(&obj->ob_ref_local, local);
1270+
#ifdef Py_REF_DEBUG
1271+
_Py_INCREF_IncRefTotal();
1272+
#endif
1273+
return true;
1274+
}
1275+
Py_ssize_t shared = _Py_atomic_load_ssize_relaxed(&obj->ob_ref_shared);
1276+
for (;;) {
1277+
// If the shared refcount is zero and the object is either merged
1278+
// or may not have weak references, then we cannot incref it.
1279+
if (shared == 0 || shared == _Py_REF_MERGED) {
1280+
return false;
1281+
}
1282+
1283+
if (_Py_atomic_compare_exchange_ssize(
1284+
&obj->ob_ref_shared, &shared, shared + (1 << _Py_REF_SHARED_SHIFT))) {
1285+
#ifdef Py_REF_DEBUG
1286+
_Py_INCREF_IncRefTotal();
1287+
#endif
1288+
return true;
1289+
}
1290+
}
1291+
}
1292+
1293+
12371294
PyOperationRef PyOperation::forOperation(PyMlirContextRef contextRef,
12381295
MlirOperation operation,
12391296
nb::object parentKeepAlive) {
@@ -1250,8 +1307,31 @@ PyOperationRef PyOperation::forOperation(PyMlirContextRef contextRef,
12501307
}
12511308
// Use existing.
12521309
PyOperation *existing = it->second.second;
1253-
nb::object pyRef = nb::borrow<nb::object>(it->second.first);
1254-
return PyOperationRef(existing, std::move(pyRef));
1310+
nb::object pyRef = nb::steal(it->second.first);
1311+
1312+
// Check whether pyRef is ongoing to be destroyed such that refcount increment
1313+
// wont keep it from deletion.
1314+
// If after incrementing the reference count its value is 1,
1315+
// it means that python object is under removal and ~PyOperation should be called.
1316+
// Thus, we should create new PyOperationRef.
1317+
if (_nb_try_inc_ref(pyRef.ptr())) {
1318+
return PyOperationRef(existing, std::move(pyRef));
1319+
}
1320+
1321+
// We should lock first liveOperationsMutex and then opMutex.
1322+
// We need to use existing->opMutex to serialize the
1323+
// access to ~PyOperation and the code below
1324+
nb::ft_lock_guard lock2(existing->opMutex);
1325+
1326+
// Invalidate existing
1327+
existing->valid = false;
1328+
1329+
// Create.
1330+
PyOperationRef result = createInstance(std::move(contextRef), operation,
1331+
std::move(parentKeepAlive));
1332+
liveOperations[operation.ptr] =
1333+
std::make_pair(result.getObject(), result.get());
1334+
return result;
12551335
}
12561336

12571337
PyOperationRef PyOperation::createDetached(PyMlirContextRef contextRef,
@@ -1282,7 +1362,8 @@ PyOperationRef PyOperation::parse(PyMlirContextRef contextRef,
12821362
return PyOperation::createDetached(std::move(contextRef), op);
12831363
}
12841364

1285-
void PyOperation::checkValid() const {
1365+
void PyOperation::checkValid() {
1366+
nb::ft_lock_guard lock(opMutex);
12861367
if (!valid) {
12871368
throw std::runtime_error("the operation has been invalidated");
12881369
}
@@ -2305,6 +2386,7 @@ void PySymbolTable::erase(PyOperationBase &symbol) {
23052386
// The operation is also erased, so we must invalidate it. There may be Python
23062387
// references to this operation so we don't want to delete it from the list of
23072388
// live operations here.
2389+
nb::ft_lock_guard lock(symbol.getOperation().opMutex);
23082390
symbol.getOperation().valid = false;
23092391
}
23102392

mlir/lib/Bindings/Python/IRModule.h

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -253,6 +253,9 @@ class PyMlirContext {
253253
struct ErrorCapture;
254254

255255
private:
256+
void _clearOperationLocked(MlirOperation op);
257+
void _clearOperationAndInsideLocked(PyOperationBase &op);
258+
256259
// Interns the mapping of live MlirContext::ptr to PyMlirContext instances,
257260
// preserving the relationship that an MlirContext maps to a single
258261
// PyMlirContext wrapper. This could be replaced in the future with an
@@ -646,8 +649,8 @@ class PyOperation : public PyOperationBase, public BaseContextObject {
646649
}
647650

648651
/// Gets the backing operation.
649-
operator MlirOperation() const { return get(); }
650-
MlirOperation get() const {
652+
operator MlirOperation() { return get(); }
653+
MlirOperation get() {
651654
checkValid();
652655
return operation;
653656
}
@@ -665,7 +668,7 @@ class PyOperation : public PyOperationBase, public BaseContextObject {
665668
assert(attached && "operation already detached");
666669
attached = false;
667670
}
668-
void checkValid() const;
671+
void checkValid();
669672

670673
/// Gets the owning block or raises an exception if the operation has no
671674
/// owning block.
@@ -700,7 +703,10 @@ class PyOperation : public PyOperationBase, public BaseContextObject {
700703
void erase();
701704

702705
/// Invalidate the operation.
703-
void setInvalid() { valid = false; }
706+
void setInvalid() {
707+
nanobind::ft_lock_guard lock(opMutex);
708+
valid = false;
709+
}
704710

705711
/// Clones this operation.
706712
nanobind::object clone(const nanobind::object &ip);
@@ -724,6 +730,8 @@ class PyOperation : public PyOperationBase, public BaseContextObject {
724730
bool attached = true;
725731
bool valid = true;
726732

733+
nanobind::ft_mutex opMutex;
734+
727735
friend class PyOperationBase;
728736
friend class PySymbolTable;
729737
};

mlir/test/python/multithreaded_tests.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -511,6 +511,45 @@ def _original_test_create_module_with_consts(self):
511511
with InsertionPoint(module.body), Location.name("c"):
512512
arith.constant(dtype, py_values[2])
513513

514+
def test_check_pyoperation_race(self):
515+
num_workers = 40
516+
num_runs = 20
517+
518+
barrier = threading.Barrier(num_workers)
519+
520+
def check_op(op):
521+
op_name = op.operation.name
522+
523+
def walk_operations(op):
524+
check_op(op)
525+
for region in op.operation.regions:
526+
for block in region:
527+
for op in block:
528+
walk_operations(op)
529+
530+
with Context():
531+
mlir_module = Module.parse(
532+
"""
533+
module @jit_sin attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
534+
func.func public @main(%arg0: tensor<f32>) -> (tensor<f32> {jax.result_info = ""}) {
535+
return %arg0 : tensor<f32>
536+
}
537+
}
538+
"""
539+
)
540+
541+
def closure():
542+
barrier.wait()
543+
544+
for _ in range(num_runs):
545+
walk_operations(mlir_module)
546+
547+
with concurrent.futures.ThreadPoolExecutor(max_workers=num_workers) as executor:
548+
futures = []
549+
for i in range(num_workers):
550+
futures.append(executor.submit(closure))
551+
assert len(list(f.result() for f in futures)) == num_workers
552+
514553

515554
if __name__ == "__main__":
516555
# Do not run the tests on CPython with GIL

0 commit comments

Comments
 (0)