29
29
#define _NpyHASH_XXROTATE (x ) ((x << 13) | (x >> 19)) /* Rotate left 13 bits */
30
30
#endif
31
31
32
+ #ifdef Py_GIL_DISABLED
33
+ // TODO: replace with PyMutex when it is public
34
+ #define LOCK_TABLE (tb ) \
35
+ if (!PyThread_acquire_lock(tb->mutex, NOWAIT_LOCK)) { \
36
+ PyThread_acquire_lock(tb->mutex, WAIT_LOCK); \
37
+ }
38
+ #define UNLOCK_TABLE (tb ) PyThread_release_lock(tb->mutex);
39
+ #define INITIALIZE_LOCK (tb ) \
40
+ tb->mutex = PyThread_allocate_lock(); \
41
+ if (tb->mutex == NULL) { \
42
+ PyErr_NoMemory(); \
43
+ PyMem_Free(res); \
44
+ return NULL; \
45
+ }
46
+ #define FREE_LOCK (tb ) \
47
+ if (tb->mutex != NULL) { \
48
+ PyThread_free_lock(tb->mutex); \
49
+ }
50
+ #else
51
+ // the GIL serializes access to the table so no need
52
+ // for locking if it is enabled
53
+ #define LOCK_TABLE (tb )
54
+ #define UNLOCK_TABLE (tb )
55
+ #define INITIALIZE_LOCK (tb )
56
+ #define FREE_LOCK (tb )
57
+ #endif
58
+
32
59
/*
33
60
* This hashing function is basically the Python tuple hash with the type
34
61
* identity hash inlined. The tuple hash itself is a reduced version of xxHash.
@@ -100,6 +127,8 @@ PyArrayIdentityHash_New(int key_len)
100
127
res -> size = 4 ; /* Start with a size of 4 */
101
128
res -> nelem = 0 ;
102
129
130
+ INITIALIZE_LOCK (res );
131
+
103
132
res -> buckets = PyMem_Calloc (4 * (key_len + 1 ), sizeof (PyObject * ));
104
133
if (res -> buckets == NULL ) {
105
134
PyErr_NoMemory ();
@@ -114,6 +143,7 @@ NPY_NO_EXPORT void
114
143
PyArrayIdentityHash_Dealloc (PyArrayIdentityHash * tb )
115
144
{
116
145
PyMem_Free (tb -> buckets );
146
+ FREE_LOCK (tb );
117
147
PyMem_Free (tb );
118
148
}
119
149
@@ -160,8 +190,9 @@ _resize_if_necessary(PyArrayIdentityHash *tb)
160
190
for (npy_intp i = 0 ; i < prev_size ; i ++ ) {
161
191
PyObject * * item = & old_table [i * (tb -> key_len + 1 )];
162
192
if (item [0 ] != NULL ) {
163
- tb -> nelem -= 1 ; /* Decrement, setitem will increment again */
164
- PyArrayIdentityHash_SetItem (tb , item + 1 , item [0 ], 1 );
193
+ PyObject * * tb_item = find_item (tb , item + 1 );
194
+ tb_item [0 ] = item [0 ];
195
+ memcpy (tb_item + 1 , item + 1 , tb -> key_len * sizeof (PyObject * ));
165
196
}
166
197
}
167
198
PyMem_Free (old_table );
@@ -188,14 +219,17 @@ NPY_NO_EXPORT int
188
219
PyArrayIdentityHash_SetItem (PyArrayIdentityHash * tb ,
189
220
PyObject * const * key , PyObject * value , int replace )
190
221
{
222
+ LOCK_TABLE (tb );
191
223
if (value != NULL && _resize_if_necessary (tb ) < 0 ) {
192
224
/* Shrink, only if a new value is added. */
225
+ UNLOCK_TABLE (tb );
193
226
return -1 ;
194
227
}
195
228
196
229
PyObject * * tb_item = find_item (tb , key );
197
230
if (value != NULL ) {
198
231
if (tb_item [0 ] != NULL && !replace ) {
232
+ UNLOCK_TABLE (tb );
199
233
PyErr_SetString (PyExc_RuntimeError ,
200
234
"Identity cache already includes the item." );
201
235
return -1 ;
@@ -209,12 +243,16 @@ PyArrayIdentityHash_SetItem(PyArrayIdentityHash *tb,
209
243
memset (tb_item , 0 , (tb -> key_len + 1 ) * sizeof (PyObject * ));
210
244
}
211
245
246
+ UNLOCK_TABLE (tb );
212
247
return 0 ;
213
248
}
214
249
215
250
216
251
NPY_NO_EXPORT PyObject *
217
252
PyArrayIdentityHash_GetItem (PyArrayIdentityHash const * tb , PyObject * const * key )
218
253
{
219
- return find_item (tb , key )[0 ];
254
+ LOCK_TABLE (tb );
255
+ PyObject * res = find_item (tb , key )[0 ];
256
+ UNLOCK_TABLE (tb );
257
+ return res ;
220
258
}
0 commit comments