Skip to content

Commit bb391b5

Browse files
authored
Merge pull request numpy#27087 from ngoldbaum/fix-dragon4
ENH: mark the dragon4 scratch space as thread-local
2 parents 1082661 + f3337a3 commit bb391b5

File tree

4 files changed: +63 −86 lines changed

numpy/_core/src/multiarray/dragon4.c

Lines changed: 41 additions & 74 deletions
Original file line numberDiff line numberDiff line change
@@ -163,28 +163,7 @@ typedef struct {
163163
char repr[16384];
164164
} Dragon4_Scratch;
165165

166-
static int _bigint_static_in_use = 0;
167-
static Dragon4_Scratch _bigint_static;
168-
169-
static Dragon4_Scratch*
170-
get_dragon4_bigint_scratch(void) {
171-
/* this test+set is not threadsafe, but no matter because we have GIL */
172-
if (_bigint_static_in_use) {
173-
PyErr_SetString(PyExc_RuntimeError,
174-
"numpy float printing code is not re-entrant. "
175-
"Ping the devs to fix it.");
176-
return NULL;
177-
}
178-
_bigint_static_in_use = 1;
179-
180-
/* in this dummy implementation we only return the static allocation */
181-
return &_bigint_static;
182-
}
183-
184-
static void
185-
free_dragon4_bigint_scratch(Dragon4_Scratch *mem){
186-
_bigint_static_in_use = 0;
187-
}
166+
static NPY_TLS Dragon4_Scratch _bigint_static;
188167

189168
/* Copy integer */
190169
static void
@@ -2210,11 +2189,11 @@ Format_floatbits(char *buffer, npy_uint32 bufferSize, BigInt *mantissa,
22102189
*/
22112190
static npy_uint32
22122191
Dragon4_PrintFloat_IEEE_binary16(
2213-
Dragon4_Scratch *scratch, npy_half *value, Dragon4_Options *opt)
2192+
npy_half *value, Dragon4_Options *opt)
22142193
{
2215-
char *buffer = scratch->repr;
2216-
const npy_uint32 bufferSize = sizeof(scratch->repr);
2217-
BigInt *bigints = scratch->bigints;
2194+
char *buffer = _bigint_static.repr;
2195+
const npy_uint32 bufferSize = sizeof(_bigint_static.repr);
2196+
BigInt *bigints = _bigint_static.bigints;
22182197

22192198
npy_uint16 val = *value;
22202199
npy_uint32 floatExponent, floatMantissa, floatSign;
@@ -2297,12 +2276,12 @@ Dragon4_PrintFloat_IEEE_binary16(
22972276
*/
22982277
static npy_uint32
22992278
Dragon4_PrintFloat_IEEE_binary32(
2300-
Dragon4_Scratch *scratch, npy_float32 *value,
2279+
npy_float32 *value,
23012280
Dragon4_Options *opt)
23022281
{
2303-
char *buffer = scratch->repr;
2304-
const npy_uint32 bufferSize = sizeof(scratch->repr);
2305-
BigInt *bigints = scratch->bigints;
2282+
char *buffer = _bigint_static.repr;
2283+
const npy_uint32 bufferSize = sizeof(_bigint_static.repr);
2284+
BigInt *bigints = _bigint_static.bigints;
23062285

23072286
union
23082287
{
@@ -2390,11 +2369,11 @@ Dragon4_PrintFloat_IEEE_binary32(
23902369
*/
23912370
static npy_uint32
23922371
Dragon4_PrintFloat_IEEE_binary64(
2393-
Dragon4_Scratch *scratch, npy_float64 *value, Dragon4_Options *opt)
2372+
npy_float64 *value, Dragon4_Options *opt)
23942373
{
2395-
char *buffer = scratch->repr;
2396-
const npy_uint32 bufferSize = sizeof(scratch->repr);
2397-
BigInt *bigints = scratch->bigints;
2374+
char *buffer = _bigint_static.repr;
2375+
const npy_uint32 bufferSize = sizeof(_bigint_static.repr);
2376+
BigInt *bigints = _bigint_static.bigints;
23982377

23992378
union
24002379
{
@@ -2505,11 +2484,11 @@ typedef struct FloatVal128 {
25052484
*/
25062485
static npy_uint32
25072486
Dragon4_PrintFloat_Intel_extended(
2508-
Dragon4_Scratch *scratch, FloatVal128 value, Dragon4_Options *opt)
2487+
FloatVal128 value, Dragon4_Options *opt)
25092488
{
2510-
char *buffer = scratch->repr;
2511-
const npy_uint32 bufferSize = sizeof(scratch->repr);
2512-
BigInt *bigints = scratch->bigints;
2489+
char *buffer = _bigint_static.repr;
2490+
const npy_uint32 bufferSize = sizeof(_bigint_static.repr);
2491+
BigInt *bigints = _bigint_static.bigints;
25132492

25142493
npy_uint32 floatExponent, floatSign;
25152494
npy_uint64 floatMantissa;
@@ -2603,7 +2582,7 @@ Dragon4_PrintFloat_Intel_extended(
26032582
*/
26042583
static npy_uint32
26052584
Dragon4_PrintFloat_Intel_extended80(
2606-
Dragon4_Scratch *scratch, npy_float80 *value, Dragon4_Options *opt)
2585+
npy_float80 *value, Dragon4_Options *opt)
26072586
{
26082587
FloatVal128 val128;
26092588
union {
@@ -2619,15 +2598,15 @@ Dragon4_PrintFloat_Intel_extended80(
26192598
val128.lo = buf80.integer.a;
26202599
val128.hi = buf80.integer.b;
26212600

2622-
return Dragon4_PrintFloat_Intel_extended(scratch, val128, opt);
2601+
return Dragon4_PrintFloat_Intel_extended(val128, opt);
26232602
}
26242603
#endif /* HAVE_LDOUBLE_INTEL_EXTENDED_10_BYTES_LE */
26252604

26262605
#ifdef HAVE_LDOUBLE_INTEL_EXTENDED_12_BYTES_LE
26272606
/* Intel's 80-bit IEEE extended precision format, 96-bit storage */
26282607
static npy_uint32
26292608
Dragon4_PrintFloat_Intel_extended96(
2630-
Dragon4_Scratch *scratch, npy_float96 *value, Dragon4_Options *opt)
2609+
npy_float96 *value, Dragon4_Options *opt)
26312610
{
26322611
FloatVal128 val128;
26332612
union {
@@ -2643,15 +2622,15 @@ Dragon4_PrintFloat_Intel_extended96(
26432622
val128.lo = buf96.integer.a;
26442623
val128.hi = buf96.integer.b;
26452624

2646-
return Dragon4_PrintFloat_Intel_extended(scratch, val128, opt);
2625+
return Dragon4_PrintFloat_Intel_extended(val128, opt);
26472626
}
26482627
#endif /* HAVE_LDOUBLE_INTEL_EXTENDED_12_BYTES_LE */
26492628

26502629
#ifdef HAVE_LDOUBLE_MOTOROLA_EXTENDED_12_BYTES_BE
26512630
/* Motorola Big-endian equivalent of the Intel-extended 96 fp format */
26522631
static npy_uint32
26532632
Dragon4_PrintFloat_Motorola_extended96(
2654-
Dragon4_Scratch *scratch, npy_float96 *value, Dragon4_Options *opt)
2633+
npy_float96 *value, Dragon4_Options *opt)
26552634
{
26562635
FloatVal128 val128;
26572636
union {
@@ -2668,7 +2647,7 @@ Dragon4_PrintFloat_Motorola_extended96(
26682647
val128.hi = buf96.integer.a >> 16;
26692648
/* once again we assume the int has same endianness as the float */
26702649

2671-
return Dragon4_PrintFloat_Intel_extended(scratch, val128, opt);
2650+
return Dragon4_PrintFloat_Intel_extended(val128, opt);
26722651
}
26732652
#endif /* HAVE_LDOUBLE_MOTOROLA_EXTENDED_12_BYTES_BE */
26742653

@@ -2688,7 +2667,7 @@ typedef union FloatUnion128
26882667
/* Intel's 80-bit IEEE extended precision format, 128-bit storage */
26892668
static npy_uint32
26902669
Dragon4_PrintFloat_Intel_extended128(
2691-
Dragon4_Scratch *scratch, npy_float128 *value, Dragon4_Options *opt)
2670+
npy_float128 *value, Dragon4_Options *opt)
26922671
{
26932672
FloatVal128 val128;
26942673
FloatUnion128 buf128;
@@ -2698,7 +2677,7 @@ Dragon4_PrintFloat_Intel_extended128(
26982677
val128.lo = buf128.integer.a;
26992678
val128.hi = buf128.integer.b;
27002679

2701-
return Dragon4_PrintFloat_Intel_extended(scratch, val128, opt);
2680+
return Dragon4_PrintFloat_Intel_extended(val128, opt);
27022681
}
27032682
#endif /* HAVE_LDOUBLE_INTEL_EXTENDED_16_BYTES_LE */
27042683

@@ -2717,11 +2696,11 @@ Dragon4_PrintFloat_Intel_extended128(
27172696
*/
27182697
static npy_uint32
27192698
Dragon4_PrintFloat_IEEE_binary128(
2720-
Dragon4_Scratch *scratch, FloatVal128 val128, Dragon4_Options *opt)
2699+
FloatVal128 val128, Dragon4_Options *opt)
27212700
{
2722-
char *buffer = scratch->repr;
2723-
const npy_uint32 bufferSize = sizeof(scratch->repr);
2724-
BigInt *bigints = scratch->bigints;
2701+
char *buffer = _bigint_static.repr;
2702+
const npy_uint32 bufferSize = sizeof(_bigint_static.repr);
2703+
BigInt *bigints = _bigint_static.bigints;
27252704

27262705
npy_uint32 floatExponent, floatSign;
27272706

@@ -2802,7 +2781,7 @@ Dragon4_PrintFloat_IEEE_binary128(
28022781
#if defined(HAVE_LDOUBLE_IEEE_QUAD_LE)
28032782
static npy_uint32
28042783
Dragon4_PrintFloat_IEEE_binary128_le(
2805-
Dragon4_Scratch *scratch, npy_float128 *value, Dragon4_Options *opt)
2784+
npy_float128 *value, Dragon4_Options *opt)
28062785
{
28072786
FloatVal128 val128;
28082787
FloatUnion128 buf128;
@@ -2811,7 +2790,7 @@ Dragon4_PrintFloat_IEEE_binary128_le(
28112790
val128.lo = buf128.integer.a;
28122791
val128.hi = buf128.integer.b;
28132792

2814-
return Dragon4_PrintFloat_IEEE_binary128(scratch, val128, opt);
2793+
return Dragon4_PrintFloat_IEEE_binary128(val128, opt);
28152794
}
28162795
#endif /* HAVE_LDOUBLE_IEEE_QUAD_LE */
28172796

@@ -2822,7 +2801,7 @@ Dragon4_PrintFloat_IEEE_binary128_le(
28222801
*/
28232802
static npy_uint32
28242803
Dragon4_PrintFloat_IEEE_binary128_be(
2825-
Dragon4_Scratch *scratch, npy_float128 *value, Dragon4_Options *opt)
2804+
npy_float128 *value, Dragon4_Options *opt)
28262805
{
28272806
FloatVal128 val128;
28282807
FloatUnion128 buf128;
@@ -2831,7 +2810,7 @@ Dragon4_PrintFloat_IEEE_binary128_be(
28312810
val128.lo = buf128.integer.b;
28322811
val128.hi = buf128.integer.a;
28332812

2834-
return Dragon4_PrintFloat_IEEE_binary128(scratch, val128, opt);
2813+
return Dragon4_PrintFloat_IEEE_binary128(val128, opt);
28352814
}
28362815
#endif /* HAVE_LDOUBLE_IEEE_QUAD_BE */
28372816

@@ -2877,11 +2856,11 @@ Dragon4_PrintFloat_IEEE_binary128_be(
28772856
*/
28782857
static npy_uint32
28792858
Dragon4_PrintFloat_IBM_double_double(
2880-
Dragon4_Scratch *scratch, npy_float128 *value, Dragon4_Options *opt)
2859+
npy_float128 *value, Dragon4_Options *opt)
28812860
{
2882-
char *buffer = scratch->repr;
2883-
const npy_uint32 bufferSize = sizeof(scratch->repr);
2884-
BigInt *bigints = scratch->bigints;
2861+
char *buffer = _bigint_static.repr;
2862+
const npy_uint32 bufferSize = sizeof(_bigint_static.repr);
2863+
BigInt *bigints = _bigint_static.bigints;
28852864

28862865
FloatVal128 val128;
28872866
FloatUnion128 buf128;
@@ -3068,16 +3047,10 @@ PyObject *\
30683047
Dragon4_Positional_##Type##_opt(npy_type *val, Dragon4_Options *opt)\
30693048
{\
30703049
PyObject *ret;\
3071-
Dragon4_Scratch *scratch = get_dragon4_bigint_scratch();\
3072-
if (scratch == NULL) {\
3073-
return NULL;\
3074-
}\
3075-
if (Dragon4_PrintFloat_##format(scratch, val, opt) < 0) {\
3076-
free_dragon4_bigint_scratch(scratch);\
3050+
if (Dragon4_PrintFloat_##format(val, opt) < 0) {\
30773051
return NULL;\
30783052
}\
3079-
ret = PyUnicode_FromString(scratch->repr);\
3080-
free_dragon4_bigint_scratch(scratch);\
3053+
ret = PyUnicode_FromString(_bigint_static.repr);\
30813054
return ret;\
30823055
}\
30833056
\
@@ -3106,16 +3079,10 @@ PyObject *\
31063079
Dragon4_Scientific_##Type##_opt(npy_type *val, Dragon4_Options *opt)\
31073080
{\
31083081
PyObject *ret;\
3109-
Dragon4_Scratch *scratch = get_dragon4_bigint_scratch();\
3110-
if (scratch == NULL) {\
3111-
return NULL;\
3112-
}\
3113-
if (Dragon4_PrintFloat_##format(scratch, val, opt) < 0) {\
3114-
free_dragon4_bigint_scratch(scratch);\
3082+
if (Dragon4_PrintFloat_##format(val, opt) < 0) { \
31153083
return NULL;\
31163084
}\
3117-
ret = PyUnicode_FromString(scratch->repr);\
3118-
free_dragon4_bigint_scratch(scratch);\
3085+
ret = PyUnicode_FromString(_bigint_static.repr);\
31193086
return ret;\
31203087
}\
31213088
PyObject *\

numpy/_core/tests/test_arrayprint.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
assert_, assert_equal, assert_raises, assert_warns, HAS_REFCOUNT,
1010
assert_raises_regex, IS_WASM
1111
)
12+
from numpy.testing._private.utils import run_threaded
1213
from numpy._core.arrayprint import _typelessdata
1314
import textwrap
1415

@@ -1249,3 +1250,10 @@ async def main():
12491250
loop = asyncio.new_event_loop()
12501251
asyncio.run(main())
12511252
loop.close()
1253+
1254+
@pytest.mark.skipif(IS_WASM, reason="wasm doesn't support threads")
1255+
def test_multithreaded_array_printing():
1256+
# the dragon4 implementation uses a static scratch space for performance
1257+
# reasons; this test makes sure it is set up in a thread-safe manner
1258+
1259+
run_threaded(TestPrintOptions().test_floatmode, 500)

numpy/_core/tests/test_multithreading.py

Lines changed: 1 addition & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,25 +1,15 @@
1-
import concurrent.futures
21
import threading
32

43
import numpy as np
54
import pytest
65

76
from numpy.testing import IS_WASM
7+
from numpy.testing._private.utils import run_threaded
88

99
if IS_WASM:
1010
pytest.skip(allow_module_level=True, reason="no threading support in wasm")
1111

1212

13-
def run_threaded(func, iters, pass_count=False):
14-
with concurrent.futures.ThreadPoolExecutor(max_workers=8) as tpe:
15-
if pass_count:
16-
futures = [tpe.submit(func, i) for i in range(iters)]
17-
else:
18-
futures = [tpe.submit(func) for _ in range(iters)]
19-
for f in futures:
20-
f.result()
21-
22-
2313
def test_parallel_randomstate_creation():
2414
# if the coercion cache is enabled and not thread-safe, creating
2515
# RandomState instances simultaneously leads to a data race

numpy/testing/_private/utils.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from warnings import WarningMessage
1818
import pprint
1919
import sysconfig
20+
import concurrent.futures
2021

2122
import numpy as np
2223
from numpy._core import (
@@ -40,7 +41,7 @@
4041
'HAS_REFCOUNT', "IS_WASM", 'suppress_warnings', 'assert_array_compare',
4142
'assert_no_gc_cycles', 'break_cycles', 'HAS_LAPACK64', 'IS_PYSTON',
4243
'_OLD_PROMOTION', 'IS_MUSL', '_SUPPORTS_SVE', 'NOGIL_BUILD',
43-
'IS_EDITABLE'
44+
'IS_EDITABLE', 'run_threaded',
4445
]
4546

4647

@@ -2697,3 +2698,14 @@ def _get_glibc_version():
26972698

26982699
_glibcver = _get_glibc_version()
26992700
_glibc_older_than = lambda x: (_glibcver != '0.0' and _glibcver < x)
2701+
2702+
2703+
def run_threaded(func, iters, pass_count=False):
2704+
"""Runs a function many times in parallel"""
2705+
with concurrent.futures.ThreadPoolExecutor(max_workers=8) as tpe:
2706+
if pass_count:
2707+
futures = [tpe.submit(func, i) for i in range(iters)]
2708+
else:
2709+
futures = [tpe.submit(func) for _ in range(iters)]
2710+
for f in futures:
2711+
f.result()

0 commit comments

Comments
 (0)