Skip to content

Commit ccff7fb

Browse files
authored
BUG: allow replacement in the dispatch cache (numpy#26693)
Allow identical replacement in dispatch cache, since this can be hit with freethreaded Python * TST: add a failing test for dispatch cache thread safety * TST: test replace=True error with new semantics First version added a lock, but maybe some information is interesting: The test code I added raises an unhandled thread exception every time I run it using a linux x86 laptop without the change to dispatching.c in this PR. Still not sure why this failure is only hit under gcc. Here's what I think is happening: right now the locking in the dispatch cache is only internal to the cache itself. The current locking strategy allows a race condition where two threads simultaneously see the cache is empty early in dispatch and then try fill it. The second thread to fill it sees the cache is filled and raises an exception, because the replace parameter for PyArrayIdentityHash_SetItem is 0, so replacements raise an exception. I don't think it's possible to support this replace feature without moving the lock from the dispatch cache struct to somewhere in the dispatching logic in dispatching.c. We'd need to lock around all the spots we check for an entry in the dispatch cache and then later insert an entry into the cache. Happy to try that approach if it turns out replacing entries in this cache is problematic for some reason. I didn't want to do that since this code is hit every time a ufunc is called so I don't want to add even larger blocks of code that need to be locked around. In practice, I don't think it's problematic to simply replace entries when this happens, at least not any more problematic than the current approach, since the dispatch cache holds borrowed references to ArrayMethod instances. Fixes numpy#26690.
1 parent b8d1012 commit ccff7fb

File tree

3 files changed

+33
-8
lines changed

3 files changed

+33
-8
lines changed

numpy/_core/src/common/npy_hashtable.c

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -210,10 +210,13 @@ _resize_if_necessary(PyArrayIdentityHash *tb)
210210
* @param value Normally a Python object, no reference counting is done.
211211
* use NULL to clear an item. If the item does not exist, no
212212
* action is performed for NULL.
213-
* @param replace If 1, allow replacements.
213+
* @param replace If 1, allow replacements. If replace is 0 an error is raised
214+
* if the stored value is different from the value to be cached. If the
215+
* value to be cached is identical to the stored value, the value to be
216+
* cached is ignored and no error is raised.
214217
* @returns 0 on success, -1 with a MemoryError or RuntimeError (if an item
215-
* is added which is already in the cache). The caller should avoid
216-
* the RuntimeError.
218+
* is added which is already in the cache and replace is 0). The
219+
* caller should avoid the RuntimeError.
217220
*/
218221
NPY_NO_EXPORT int
219222
PyArrayIdentityHash_SetItem(PyArrayIdentityHash *tb,
@@ -228,10 +231,10 @@ PyArrayIdentityHash_SetItem(PyArrayIdentityHash *tb,
228231

229232
PyObject **tb_item = find_item(tb, key);
230233
if (value != NULL) {
231-
if (tb_item[0] != NULL && !replace) {
234+
if (tb_item[0] != NULL && tb_item[0] != value && !replace) {
232235
UNLOCK_TABLE(tb);
233236
PyErr_SetString(PyExc_RuntimeError,
234-
"Identity cache already includes the item.");
237+
"Identity cache already includes an item with this key.");
235238
return -1;
236239
}
237240
tb_item[0] = value;

numpy/_core/tests/test_hashtable.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,12 @@ def test_identity_hashtable(key_length, length):
2424
res = identityhash_tester(key_length, keys_vals, replace=True)
2525
assert res is expected
2626

27-
# check that ensuring one duplicate definitely raises:
28-
keys_vals.insert(0, keys_vals[-2])
27+
if length == 1:
28+
return
29+
30+
# add a new item with a key that is already used and a new value, this
31+
# should error if replace is False, see gh-26690
32+
new_key = (keys_vals[1][0], object())
33+
keys_vals[0] = new_key
2934
with pytest.raises(RuntimeError):
3035
identityhash_tester(key_length, keys_vals)

numpy/_core/tests/test_multithreading.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import concurrent.futures
2+
import threading
23

34
import numpy as np
45
import pytest
@@ -30,13 +31,29 @@ def func(seed):
3031
def test_parallel_ufunc_execution():
3132
# if the loop data cache or dispatch cache are not thread-safe
3233
# computing ufuncs simultaneously in multiple threads leads
33-
# to a data race
34+
# to a data race that causes crashes or spurious exceptions
3435
def func():
3536
arr = np.random.random((25,))
3637
np.isnan(arr)
3738

3839
run_threaded(func, 500)
3940

41+
# see gh-26690
42+
NUM_THREADS = 50
43+
44+
b = threading.Barrier(NUM_THREADS)
45+
46+
a = np.ones(1000)
47+
48+
def f():
49+
b.wait()
50+
return a.sum()
51+
52+
threads = [threading.Thread(target=f) for _ in range(NUM_THREADS)]
53+
54+
[t.start() for t in threads]
55+
[t.join() for t in threads]
56+
4057
def test_temp_elision_thread_safety():
4158
amid = np.ones(50000)
4259
bmid = np.ones(50000)

0 commit comments

Comments
 (0)