Skip to content

Commit 3662311

Browse files
committed
ENH: mark the dragon4 scratch space as thread-local
1 parent 36b7ff9 commit 3662311

File tree

4 files changed

+24
-15
lines changed

4 files changed

+24
-15
lines changed

numpy/_core/src/multiarray/dragon4.c

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -163,12 +163,11 @@ typedef struct {
163163
char repr[16384];
164164
} Dragon4_Scratch;
165165

166-
static int _bigint_static_in_use = 0;
167-
static Dragon4_Scratch _bigint_static;
166+
static NPY_TLS int _bigint_static_in_use = 0;
167+
static NPY_TLS Dragon4_Scratch _bigint_static;
168168

169169
static Dragon4_Scratch*
170170
get_dragon4_bigint_scratch(void) {
171-
/* this test+set is not threadsafe, but no matter because we have GIL */
172171
if (_bigint_static_in_use) {
173172
PyErr_SetString(PyExc_RuntimeError,
174173
"numpy float printing code is not re-entrant. "

numpy/_core/tests/test_arrayprint.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
assert_, assert_equal, assert_raises, assert_warns, HAS_REFCOUNT,
1010
assert_raises_regex, IS_WASM
1111
)
12+
from numpy.testing._private.utils import run_threaded
1213
from numpy._core.arrayprint import _typelessdata
1314
import textwrap
1415

@@ -1249,3 +1250,10 @@ async def main():
12491250
loop = asyncio.new_event_loop()
12501251
asyncio.run(main())
12511252
loop.close()
1253+
1254+
@pytest.mark.skipif(IS_WASM, reason="wasm doesn't support threads")
1255+
def test_multithreaded_array_printing():
1256+
# the dragon4 implementation uses a static scratch space for performance
1257+
# reasons this test makes sure it is set up in a thread-safe manner
1258+
1259+
run_threaded(TestPrintOptions().test_floatmode, 500)

numpy/_core/tests/test_multithreading.py

Lines changed: 1 addition & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,25 +1,15 @@
1-
import concurrent.futures
21
import threading
32

43
import numpy as np
54
import pytest
65

76
from numpy.testing import IS_WASM
7+
from numpy.testing._private.utils import run_threaded
88

99
if IS_WASM:
1010
pytest.skip(allow_module_level=True, reason="no threading support in wasm")
1111

1212

13-
def run_threaded(func, iters, pass_count=False):
14-
with concurrent.futures.ThreadPoolExecutor(max_workers=8) as tpe:
15-
if pass_count:
16-
futures = [tpe.submit(func, i) for i in range(iters)]
17-
else:
18-
futures = [tpe.submit(func) for _ in range(iters)]
19-
for f in futures:
20-
f.result()
21-
22-
2313
def test_parallel_randomstate_creation():
2414
# if the coercion cache is enabled and not thread-safe, creating
2515
# RandomState instances simultaneously leads to a data race

numpy/testing/_private/utils.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from warnings import WarningMessage
1818
import pprint
1919
import sysconfig
20+
import concurrent.futures
2021

2122
import numpy as np
2223
from numpy._core import (
@@ -40,7 +41,7 @@
4041
'HAS_REFCOUNT', "IS_WASM", 'suppress_warnings', 'assert_array_compare',
4142
'assert_no_gc_cycles', 'break_cycles', 'HAS_LAPACK64', 'IS_PYSTON',
4243
'_OLD_PROMOTION', 'IS_MUSL', '_SUPPORTS_SVE', 'NOGIL_BUILD',
43-
'IS_EDITABLE'
44+
'IS_EDITABLE', 'run_threaded',
4445
]
4546

4647

@@ -2697,3 +2698,14 @@ def _get_glibc_version():
26972698

26982699
_glibcver = _get_glibc_version()
26992700
_glibc_older_than = lambda x: (_glibcver != '0.0' and _glibcver < x)
2701+
2702+
2703+
def run_threaded(func, iters, pass_count=False):
2704+
"""Runs a function many times in parallel"""
2705+
with concurrent.futures.ThreadPoolExecutor(max_workers=8) as tpe:
2706+
if pass_count:
2707+
futures = [tpe.submit(func, i) for i in range(iters)]
2708+
else:
2709+
futures = [tpe.submit(func) for _ in range(iters)]
2710+
for f in futures:
2711+
f.result()

0 commit comments

Comments
 (0)