Skip to content

Commit 05f8351

Browse files
authored
Merge pull request numpy#26348 from ngoldbaum/nogil-ufunc-caches
NOGIL: Make loop data cache and dispatch cache thread-safe in nogil build
2 parents 8a93bb0 + 86e39a0 commit 05f8351

File tree

4 files changed

+79
-14
lines changed

4 files changed

+79
-14
lines changed

numpy/_core/src/common/npy_hashtable.c

Lines changed: 41 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,33 @@
2929
#define _NpyHASH_XXROTATE(x) ((x << 13) | (x >> 19)) /* Rotate left 13 bits */
3030
#endif
3131

32+
#ifdef Py_GIL_DISABLED
/*
 * TODO: replace with PyMutex when it is public.
 *
 * Every multi-statement macro below is wrapped in do { } while (0) so that
 * each expands to exactly one statement and remains safe inside unbraced
 * if/else bodies at the call sites.
 */
#define LOCK_TABLE(tb)                                                  \
        do {                                                            \
            /* Try a non-blocking acquire first; fall back to waiting. */\
            if (!PyThread_acquire_lock(tb->mutex, NOWAIT_LOCK)) {       \
                PyThread_acquire_lock(tb->mutex, WAIT_LOCK);            \
            }                                                           \
        } while (0)
#define UNLOCK_TABLE(tb) PyThread_release_lock(tb->mutex)
/*
 * NOTE: INITIALIZE_LOCK is only usable inside PyArrayIdentityHash_New:
 * on allocation failure it frees the enclosing function's local `res`
 * and executes `return NULL` from the caller.
 */
#define INITIALIZE_LOCK(tb)                                             \
        do {                                                            \
            tb->mutex = PyThread_allocate_lock();                       \
            if (tb->mutex == NULL) {                                    \
                PyErr_NoMemory();                                       \
                PyMem_Free(res);                                        \
                return NULL;                                            \
            }                                                           \
        } while (0)
#define FREE_LOCK(tb)                                                   \
        do {                                                            \
            if (tb->mutex != NULL) {                                    \
                PyThread_free_lock(tb->mutex);                          \
            }                                                           \
        } while (0)
#else
/* The GIL serializes all access to the table, so locking is unnecessary
 * and all lock macros expand to nothing. */
#define LOCK_TABLE(tb)
#define UNLOCK_TABLE(tb)
#define INITIALIZE_LOCK(tb)
#define FREE_LOCK(tb)
#endif
58+
3259
/*
3360
* This hashing function is basically the Python tuple hash with the type
3461
* identity hash inlined. The tuple hash itself is a reduced version of xxHash.
@@ -100,6 +127,8 @@ PyArrayIdentityHash_New(int key_len)
100127
res->size = 4; /* Start with a size of 4 */
101128
res->nelem = 0;
102129

130+
INITIALIZE_LOCK(res);
131+
103132
res->buckets = PyMem_Calloc(4 * (key_len + 1), sizeof(PyObject *));
104133
if (res->buckets == NULL) {
105134
PyErr_NoMemory();
@@ -114,6 +143,7 @@ NPY_NO_EXPORT void
114143
PyArrayIdentityHash_Dealloc(PyArrayIdentityHash *tb)
{
    /*
     * Tear down the table: release the mutex first (a no-op on GIL
     * builds), then the bucket storage, then the table struct itself.
     * NOTE(review): stored pointers are not DECREF'd here — presumably
     * the table holds borrowed references; confirm against the callers.
     */
    FREE_LOCK(tb);
    PyMem_Free(tb->buckets);
    PyMem_Free(tb);
}
119149

@@ -160,8 +190,9 @@ _resize_if_necessary(PyArrayIdentityHash *tb)
160190
for (npy_intp i = 0; i < prev_size; i++) {
161191
PyObject **item = &old_table[i * (tb->key_len + 1)];
162192
if (item[0] != NULL) {
163-
tb->nelem -= 1; /* Decrement, setitem will increment again */
164-
PyArrayIdentityHash_SetItem(tb, item+1, item[0], 1);
193+
PyObject **tb_item = find_item(tb, item + 1);
194+
tb_item[0] = item[0];
195+
memcpy(tb_item+1, item+1, tb->key_len * sizeof(PyObject *));
165196
}
166197
}
167198
PyMem_Free(old_table);
@@ -188,14 +219,17 @@ NPY_NO_EXPORT int
188219
PyArrayIdentityHash_SetItem(PyArrayIdentityHash *tb,
189220
PyObject *const *key, PyObject *value, int replace)
190221
{
222+
LOCK_TABLE(tb);
191223
if (value != NULL && _resize_if_necessary(tb) < 0) {
192224
/* Shrink, only if a new value is added. */
225+
UNLOCK_TABLE(tb);
193226
return -1;
194227
}
195228

196229
PyObject **tb_item = find_item(tb, key);
197230
if (value != NULL) {
198231
if (tb_item[0] != NULL && !replace) {
232+
UNLOCK_TABLE(tb);
199233
PyErr_SetString(PyExc_RuntimeError,
200234
"Identity cache already includes the item.");
201235
return -1;
@@ -209,12 +243,16 @@ PyArrayIdentityHash_SetItem(PyArrayIdentityHash *tb,
209243
memset(tb_item, 0, (tb->key_len + 1) * sizeof(PyObject *));
210244
}
211245

246+
UNLOCK_TABLE(tb);
212247
return 0;
213248
}
214249

215250

216251
NPY_NO_EXPORT PyObject *
PyArrayIdentityHash_GetItem(PyArrayIdentityHash const *tb, PyObject *const *key)
{
    /*
     * Look up `key` and return the stored value (NULL when absent).
     * The entire lookup runs under the table lock (a no-op on GIL
     * builds) so a concurrent resize cannot invalidate the bucket
     * between find_item() and the read.
     */
    PyObject *value;

    LOCK_TABLE(tb);
    value = find_item(tb, key)[0];
    UNLOCK_TABLE(tb);
    return value;
}

numpy/_core/src/common/npy_hashtable.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,9 @@ typedef struct {
1313
PyObject **buckets;
1414
npy_intp size; /* current size */
1515
npy_intp nelem; /* number of elements */
16+
#ifdef Py_GIL_DISABLED
17+
PyThread_type_lock *mutex;
18+
#endif
1619
} PyArrayIdentityHash;
1720

1821

numpy/_core/src/umath/legacy_array_method.c

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -33,37 +33,43 @@ typedef struct {
3333

3434

3535
/* Use a free list, since we should normally only need one at a time */
36+
#ifndef Py_GIL_DISABLED
3637
#define NPY_LOOP_DATA_CACHE_SIZE 5
3738
static int loop_data_num_cached = 0;
3839
static legacy_array_method_auxdata *loop_data_cache[NPY_LOOP_DATA_CACHE_SIZE];
39-
40+
#else
41+
#define NPY_LOOP_DATA_CACHE_SIZE 0
42+
#endif
4043

4144
static void
4245
legacy_array_method_auxdata_free(NpyAuxData *data)
4346
{
47+
#if NPY_LOOP_DATA_CACHE_SIZE > 0
4448
if (loop_data_num_cached < NPY_LOOP_DATA_CACHE_SIZE) {
4549
loop_data_cache[loop_data_num_cached] = (
4650
(legacy_array_method_auxdata *)data);
4751
loop_data_num_cached++;
4852
}
49-
else {
53+
else
54+
#endif
55+
{
5056
PyMem_Free(data);
5157
}
5258
}
5359

54-
#undef NPY_LOOP_DATA_CACHE_SIZE
55-
56-
5760
NpyAuxData *
5861
get_new_loop_data(
5962
PyUFuncGenericFunction loop, void *user_data, int pyerr_check)
6063
{
6164
legacy_array_method_auxdata *data;
65+
#if NPY_LOOP_DATA_CACHE_SIZE > 0
6266
if (NPY_LIKELY(loop_data_num_cached > 0)) {
6367
loop_data_num_cached--;
6468
data = loop_data_cache[loop_data_num_cached];
6569
}
66-
else {
70+
else
71+
#endif
72+
{
6773
data = PyMem_Malloc(sizeof(legacy_array_method_auxdata));
6874
if (data == NULL) {
6975
return NULL;
@@ -77,6 +83,7 @@ get_new_loop_data(
7783
return (NpyAuxData *)data;
7884
}
7985

86+
#undef NPY_LOOP_DATA_CACHE_SIZE
8087

8188
/*
8289
* This is a thin wrapper around the legacy loop signature.

numpy/_core/tests/test_multithreading.py

Lines changed: 22 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -9,13 +9,30 @@
99
pytest.skip(allow_module_level=True, reason="no threading support in wasm")
1010

1111

12-
def test_parallel_errstate_creation():
12+
def run_threaded(func, iters, pass_count=False):
    """Run ``func`` ``iters`` times concurrently on an 8-worker thread pool.

    When ``pass_count`` is true, each call receives its iteration index as
    the sole positional argument.  Blocks until every submitted call has
    completed and re-raises the first worker exception, if any, via
    ``Future.result()``.
    """
    with concurrent.futures.ThreadPoolExecutor(max_workers=8) as pool:
        if pass_count:
            pending = [pool.submit(func, count) for count in range(iters)]
        else:
            pending = [pool.submit(func) for _ in range(iters)]
        for task in pending:
            task.result()
20+
21+
22+
def test_parallel_randomstate_creation():
    # Constructing RandomState exercises the coercion cache; building many
    # instances simultaneously races on that cache unless it is thread-safe.
    def make_state(seed):
        np.random.RandomState(seed)

    run_threaded(make_state, 500, pass_count=True)
29+
30+
def test_parallel_ufunc_execution():
    # Evaluating a ufunc touches the loop data cache and the dispatch
    # cache; running many evaluations concurrently races on those caches
    # unless they are thread-safe.
    def call_ufunc():
        values = np.random.random((25,))
        np.isnan(values)

    run_threaded(call_ufunc, 500)

0 commit comments

Comments
 (0)