Skip to content

Commit 86e39a0

Browse files
committed
MNT: add locking for PyArrayIdentityHash
1 parent 6261524 commit 86e39a0

File tree

2 files changed

+44
-3
lines changed

2 files changed

+44
-3
lines changed

numpy/_core/src/common/npy_hashtable.c

Lines changed: 41 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,33 @@
2929
#define _NpyHASH_XXROTATE(x) ((x << 13) | (x >> 19)) /* Rotate left 13 bits */
3030
#endif
3131

32+
#ifdef Py_GIL_DISABLED
33+
// TODO: replace with PyMutex when it is public
34+
#define LOCK_TABLE(tb) \
35+
if (!PyThread_acquire_lock(tb->mutex, NOWAIT_LOCK)) { \
36+
PyThread_acquire_lock(tb->mutex, WAIT_LOCK); \
37+
}
38+
#define UNLOCK_TABLE(tb) PyThread_release_lock(tb->mutex);
39+
#define INITIALIZE_LOCK(tb) \
40+
tb->mutex = PyThread_allocate_lock(); \
41+
if (tb->mutex == NULL) { \
42+
PyErr_NoMemory(); \
43+
PyMem_Free(res); \
44+
return NULL; \
45+
}
46+
#define FREE_LOCK(tb) \
47+
if (tb->mutex != NULL) { \
48+
PyThread_free_lock(tb->mutex); \
49+
}
50+
#else
51+
// the GIL serializes access to the table so no need
52+
// for locking if it is enabled
53+
#define LOCK_TABLE(tb)
54+
#define UNLOCK_TABLE(tb)
55+
#define INITIALIZE_LOCK(tb)
56+
#define FREE_LOCK(tb)
57+
#endif
58+
3259
/*
3360
* This hashing function is basically the Python tuple hash with the type
3461
* identity hash inlined. The tuple hash itself is a reduced version of xxHash.
@@ -100,6 +127,8 @@ PyArrayIdentityHash_New(int key_len)
100127
res->size = 4; /* Start with a size of 4 */
101128
res->nelem = 0;
102129

130+
INITIALIZE_LOCK(res);
131+
103132
res->buckets = PyMem_Calloc(4 * (key_len + 1), sizeof(PyObject *));
104133
if (res->buckets == NULL) {
105134
PyErr_NoMemory();
@@ -114,6 +143,7 @@ NPY_NO_EXPORT void
114143
PyArrayIdentityHash_Dealloc(PyArrayIdentityHash *tb)
115144
{
116145
PyMem_Free(tb->buckets);
146+
FREE_LOCK(tb);
117147
PyMem_Free(tb);
118148
}
119149

@@ -160,8 +190,9 @@ _resize_if_necessary(PyArrayIdentityHash *tb)
160190
for (npy_intp i = 0; i < prev_size; i++) {
161191
PyObject **item = &old_table[i * (tb->key_len + 1)];
162192
if (item[0] != NULL) {
163-
tb->nelem -= 1; /* Decrement, setitem will increment again */
164-
PyArrayIdentityHash_SetItem(tb, item+1, item[0], 1);
193+
PyObject **tb_item = find_item(tb, item + 1);
194+
tb_item[0] = item[0];
195+
memcpy(tb_item+1, item+1, tb->key_len * sizeof(PyObject *));
165196
}
166197
}
167198
PyMem_Free(old_table);
@@ -188,14 +219,17 @@ NPY_NO_EXPORT int
188219
PyArrayIdentityHash_SetItem(PyArrayIdentityHash *tb,
189220
PyObject *const *key, PyObject *value, int replace)
190221
{
222+
LOCK_TABLE(tb);
191223
if (value != NULL && _resize_if_necessary(tb) < 0) {
192224
/* Shrink, only if a new value is added. */
225+
UNLOCK_TABLE(tb);
193226
return -1;
194227
}
195228

196229
PyObject **tb_item = find_item(tb, key);
197230
if (value != NULL) {
198231
if (tb_item[0] != NULL && !replace) {
232+
UNLOCK_TABLE(tb);
199233
PyErr_SetString(PyExc_RuntimeError,
200234
"Identity cache already includes the item.");
201235
return -1;
@@ -209,12 +243,16 @@ PyArrayIdentityHash_SetItem(PyArrayIdentityHash *tb,
209243
memset(tb_item, 0, (tb->key_len + 1) * sizeof(PyObject *));
210244
}
211245

246+
UNLOCK_TABLE(tb);
212247
return 0;
213248
}
214249

215250

216251
NPY_NO_EXPORT PyObject *
217252
PyArrayIdentityHash_GetItem(PyArrayIdentityHash const *tb, PyObject *const *key)
218253
{
219-
return find_item(tb, key)[0];
254+
LOCK_TABLE(tb);
255+
PyObject *res = find_item(tb, key)[0];
256+
UNLOCK_TABLE(tb);
257+
return res;
220258
}

numpy/_core/src/common/npy_hashtable.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,9 @@ typedef struct {
1313
PyObject **buckets;
1414
npy_intp size; /* current size */
1515
npy_intp nelem; /* number of elements */
16+
#ifdef Py_GIL_DISABLED
17+
PyThread_type_lock *mutex;
18+
#endif
1619
} PyArrayIdentityHash;
1720

1821

0 commit comments

Comments
 (0)