Skip to content

Commit 44d9943

Browse files
committed
Fix incorrect usage of mutex with identical dtypes
1 parent f5c800e commit 44d9943

File tree

4 files changed

+78
-73
lines changed

4 files changed

+78
-73
lines changed

stringdtype/stringdtype/src/casts.c

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -109,15 +109,13 @@ string_to_string(PyArrayMethod_Context *context, char *const data[],
109109
out += out_stride;
110110
}
111111

112-
NPY_STRING_RELEASE_ALLOCATOR(idescr);
113-
NPY_STRING_RELEASE_ALLOCATOR(odescr);
112+
NPY_STRING_RELEASE_ALLOCATOR2(odescr, idescr);
114113

115114
return 0;
116115

117116
fail:
118117

119-
NPY_STRING_RELEASE_ALLOCATOR(idescr);
120-
NPY_STRING_RELEASE_ALLOCATOR(odescr);
118+
NPY_STRING_RELEASE_ALLOCATOR2(odescr, idescr);
121119

122120
return -1;
123121
}

stringdtype/stringdtype/src/dtype.h

Lines changed: 16 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -19,21 +19,9 @@
1919
#include "numpy/npy_math.h"
2020
#include "numpy/ufuncobject.h"
2121

22-
// from dtypemeta.h, not public in numpy
23-
#define NPY_DTYPE(descr) ((PyArray_DTypeMeta *)Py_TYPE(descr))
24-
25-
#define NPY_STRING_ACQUIRE_ALLOCATOR(descr) \
26-
if (PyGILState_Check()) { \
27-
if (!PyThread_acquire_lock(descr->allocator_lock, NOWAIT_LOCK)) { \
28-
Py_BEGIN_ALLOW_THREADS PyThread_acquire_lock( \
29-
descr->allocator_lock, WAIT_LOCK); \
30-
Py_END_ALLOW_THREADS \
31-
} \
32-
} \
33-
else { \
34-
if (!PyThread_acquire_lock(descr->allocator_lock, NOWAIT_LOCK)) { \
35-
PyThread_acquire_lock(descr->allocator_lock, WAIT_LOCK); \
36-
} \
22+
#define NPY_STRING_ACQUIRE_ALLOCATOR(descr) \
23+
if (!PyThread_acquire_lock(descr->allocator_lock, NOWAIT_LOCK)) { \
24+
PyThread_acquire_lock(descr->allocator_lock, WAIT_LOCK); \
3725
}
3826

3927
#define NPY_STRING_ACQUIRE_ALLOCATOR2(descr1, descr2) \
@@ -53,6 +41,19 @@
5341

5442
#define NPY_STRING_RELEASE_ALLOCATOR(descr) \
5543
PyThread_release_lock(descr->allocator_lock);
44+
#define NPY_STRING_RELEASE_ALLOCATOR2(descr1, descr2) \
45+
NPY_STRING_RELEASE_ALLOCATOR(descr1); \
46+
if (descr1 != descr2) { \
47+
NPY_STRING_RELEASE_ALLOCATOR(descr2); \
48+
}
49+
#define NPY_STRING_RELEASE_ALLOCATOR3(descr1, descr2, descr3) \
50+
NPY_STRING_RELEASE_ALLOCATOR(descr1); \
51+
if (descr1 != descr2) { \
52+
NPY_STRING_RELEASE_ALLOCATOR(descr2); \
53+
} \
54+
if (descr1 != descr3 && descr2 != descr3) { \
55+
NPY_STRING_RELEASE_ALLOCATOR(descr3); \
56+
}
5657

5758
typedef struct {
5859
PyArray_Descr base;

stringdtype/stringdtype/src/umath.c

Lines changed: 20 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -156,13 +156,11 @@ multiply_resolve_descriptors(
156156
iin += i_stride; \
157157
out += o_stride; \
158158
} \
159-
NPY_STRING_RELEASE_ALLOCATOR(idescr); \
160-
NPY_STRING_RELEASE_ALLOCATOR(odescr); \
159+
NPY_STRING_RELEASE_ALLOCATOR2(idescr, odescr); \
161160
return 0; \
162161
\
163162
fail: \
164-
NPY_STRING_RELEASE_ALLOCATOR(idescr); \
165-
NPY_STRING_RELEASE_ALLOCATOR(odescr); \
163+
NPY_STRING_RELEASE_ALLOCATOR2(idescr, odescr); \
166164
return -1; \
167165
} \
168166
\
@@ -416,15 +414,11 @@ add_strided_loop(PyArrayMethod_Context *context, char *const data[],
416414
in2 += in2_stride;
417415
out += out_stride;
418416
}
419-
NPY_STRING_RELEASE_ALLOCATOR(s1descr);
420-
NPY_STRING_RELEASE_ALLOCATOR(s2descr);
421-
NPY_STRING_RELEASE_ALLOCATOR(odescr);
417+
NPY_STRING_RELEASE_ALLOCATOR3(s1descr, s2descr, odescr);
422418
return 0;
423419

424420
fail:
425-
NPY_STRING_RELEASE_ALLOCATOR(s1descr);
426-
NPY_STRING_RELEASE_ALLOCATOR(s2descr);
427-
NPY_STRING_RELEASE_ALLOCATOR(odescr);
421+
NPY_STRING_RELEASE_ALLOCATOR3(s1descr, s2descr, odescr);
428422
return -1;
429423
}
430424

@@ -476,15 +470,11 @@ maximum_strided_loop(PyArrayMethod_Context *context, char *const data[],
476470
out += out_stride;
477471
}
478472

479-
NPY_STRING_RELEASE_ALLOCATOR(in1_descr);
480-
NPY_STRING_RELEASE_ALLOCATOR(in2_descr);
481-
NPY_STRING_RELEASE_ALLOCATOR(out_descr);
473+
NPY_STRING_RELEASE_ALLOCATOR3(in1_descr, in2_descr, out_descr);
482474
return 0;
483475

484476
fail:
485-
NPY_STRING_RELEASE_ALLOCATOR(in1_descr);
486-
NPY_STRING_RELEASE_ALLOCATOR(in2_descr);
487-
NPY_STRING_RELEASE_ALLOCATOR(out_descr);
477+
NPY_STRING_RELEASE_ALLOCATOR3(in1_descr, in2_descr, out_descr);
488478
return -1;
489479
}
490480

@@ -536,15 +526,11 @@ minimum_strided_loop(PyArrayMethod_Context *context, char *const data[],
536526
out += out_stride;
537527
}
538528

539-
NPY_STRING_RELEASE_ALLOCATOR(in1_descr);
540-
NPY_STRING_RELEASE_ALLOCATOR(in2_descr);
541-
NPY_STRING_RELEASE_ALLOCATOR(out_descr);
529+
NPY_STRING_RELEASE_ALLOCATOR3(in1_descr, in2_descr, out_descr);
542530
return 0;
543531

544532
fail:
545-
NPY_STRING_RELEASE_ALLOCATOR(in1_descr);
546-
NPY_STRING_RELEASE_ALLOCATOR(in2_descr);
547-
NPY_STRING_RELEASE_ALLOCATOR(out_descr);
533+
NPY_STRING_RELEASE_ALLOCATOR3(in1_descr, in2_descr, out_descr);
548534
return -1;
549535
}
550536

@@ -618,14 +604,12 @@ string_equal_strided_loop(PyArrayMethod_Context *context, char *const data[],
618604
out += out_stride;
619605
}
620606

621-
NPY_STRING_RELEASE_ALLOCATOR(descr1);
622-
NPY_STRING_RELEASE_ALLOCATOR(descr2);
607+
NPY_STRING_RELEASE_ALLOCATOR2(descr1, descr2);
623608

624609
return 0;
625610

626611
fail:
627-
NPY_STRING_RELEASE_ALLOCATOR(descr1);
628-
NPY_STRING_RELEASE_ALLOCATOR(descr2);
612+
NPY_STRING_RELEASE_ALLOCATOR2(descr1, descr2);
629613

630614
return -1;
631615
}
@@ -700,14 +684,12 @@ string_not_equal_strided_loop(PyArrayMethod_Context *context,
700684
out += out_stride;
701685
}
702686

703-
NPY_STRING_RELEASE_ALLOCATOR(descr1);
704-
NPY_STRING_RELEASE_ALLOCATOR(descr2);
687+
NPY_STRING_RELEASE_ALLOCATOR2(descr1, descr2);
705688

706689
return 0;
707690

708691
fail:
709-
NPY_STRING_RELEASE_ALLOCATOR(descr1);
710-
NPY_STRING_RELEASE_ALLOCATOR(descr2);
692+
NPY_STRING_RELEASE_ALLOCATOR2(descr1, descr2);
711693

712694
return -1;
713695
}
@@ -779,14 +761,12 @@ string_greater_strided_loop(PyArrayMethod_Context *context, char *const data[],
779761
out += out_stride;
780762
}
781763

782-
NPY_STRING_RELEASE_ALLOCATOR(descr1);
783-
NPY_STRING_RELEASE_ALLOCATOR(descr2);
764+
NPY_STRING_RELEASE_ALLOCATOR2(descr1, descr2);
784765

785766
return 0;
786767

787768
fail:
788-
NPY_STRING_RELEASE_ALLOCATOR(descr1);
789-
NPY_STRING_RELEASE_ALLOCATOR(descr2);
769+
NPY_STRING_RELEASE_ALLOCATOR2(descr1, descr2);
790770

791771
return -1;
792772
}
@@ -860,14 +840,12 @@ string_greater_equal_strided_loop(PyArrayMethod_Context *context,
860840
out += out_stride;
861841
}
862842

863-
NPY_STRING_RELEASE_ALLOCATOR(descr1);
864-
NPY_STRING_RELEASE_ALLOCATOR(descr2);
843+
NPY_STRING_RELEASE_ALLOCATOR2(descr1, descr2);
865844

866845
return 0;
867846

868847
fail:
869-
NPY_STRING_RELEASE_ALLOCATOR(descr1);
870-
NPY_STRING_RELEASE_ALLOCATOR(descr2);
848+
NPY_STRING_RELEASE_ALLOCATOR2(descr1, descr2);
871849

872850
return -1;
873851
}
@@ -938,14 +916,12 @@ string_less_strided_loop(PyArrayMethod_Context *context, char *const data[],
938916
out += out_stride;
939917
}
940918

941-
NPY_STRING_RELEASE_ALLOCATOR(descr1);
942-
NPY_STRING_RELEASE_ALLOCATOR(descr2);
919+
NPY_STRING_RELEASE_ALLOCATOR2(descr1, descr2);
943920

944921
return 0;
945922

946923
fail:
947-
NPY_STRING_RELEASE_ALLOCATOR(descr1);
948-
NPY_STRING_RELEASE_ALLOCATOR(descr2);
924+
NPY_STRING_RELEASE_ALLOCATOR2(descr1, descr2);
949925

950926
return -1;
951927
}
@@ -1018,14 +994,12 @@ string_less_equal_strided_loop(PyArrayMethod_Context *context,
1018994
out += out_stride;
1019995
}
1020996

1021-
NPY_STRING_RELEASE_ALLOCATOR(descr1);
1022-
NPY_STRING_RELEASE_ALLOCATOR(descr2);
997+
NPY_STRING_RELEASE_ALLOCATOR2(descr1, descr2);
1023998

1024999
return 0;
10251000

10261001
fail:
1027-
NPY_STRING_RELEASE_ALLOCATOR(descr1);
1028-
NPY_STRING_RELEASE_ALLOCATOR(descr2);
1002+
NPY_STRING_RELEASE_ALLOCATOR2(descr1, descr2);
10291003

10301004
return -1;
10311005
}

stringdtype/tests/test_stringdtype.py

Lines changed: 40 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,14 @@ def string_list():
2020
return ["abc", "def", "ghi" * 10, "A¢☃€ 😊" * 100, "Abc" * 1000, "DEF"]
2121

2222

23+
@pytest.fixture
24+
def random_string_list():
25+
chars = list(string.ascii_letters + string.digits)
26+
chars = np.array(chars, dtype="U1")
27+
ret = np.random.choice(chars, size=100 * 1000, replace=True)
28+
return ret.view("U100")
29+
30+
2331
pd_param = pytest.param(
2432
pd_NA,
2533
marks=pytest.mark.skipif(pd_NA is None, reason="pandas is not installed"),
@@ -202,15 +210,12 @@ def test_unicode_casts(dtype, strings):
202210
)
203211

204212

205-
def test_additional_unicode_cast(dtype):
206-
RANDS_CHARS = np.array(
207-
list(string.ascii_letters + string.digits), dtype=(np.str_, 1)
208-
)
209-
arr = np.random.choice(RANDS_CHARS, size=10 * 100_000, replace=True).view(
210-
"U10"
211-
)
213+
def test_additional_unicode_cast(random_string_list, dtype):
214+
arr = np.array(random_string_list, dtype=dtype)
212215
np.testing.assert_array_equal(arr, arr.astype(dtype))
213-
np.testing.assert_array_equal(arr, arr.astype(dtype).astype("U10"))
216+
np.testing.assert_array_equal(
217+
arr, arr.astype(dtype).astype(random_string_list.dtype)
218+
)
214219

215220

216221
def test_insert_scalar(dtype, string_list):
@@ -671,7 +676,9 @@ def test_ufunc_multiply(dtype, string_list, other, other_dtype, use_out):
671676
result = [s * other for s in string_list]
672677

673678
if use_out:
679+
arr_cache = arr.copy()
674680
lres = np.multiply(arr, other, out=arr)
681+
arr[:] = arr_cache
675682
rres = np.multiply(other, arr, out=arr)
676683
else:
677684
lres = arr * other
@@ -793,3 +800,28 @@ def test_growing_strings(dtype):
793800
uarr = uarr + uarr
794801

795802
np.testing.assert_array_equal(arr, uarr)
803+
804+
805+
def test_threaded_access_and_mutation(dtype, random_string_list):
806+
# this test uses an RNG and may crash or cause deadlocks if there is a
807+
# threading bug
808+
rng = np.random.default_rng(0x4D3D3D3)
809+
810+
def func(arr):
811+
rnd = rng.random()
812+
# either write to random locations in the array, compute a ufunc, or
813+
# re-initialize the array
814+
if rnd < 0.3333:
815+
num = np.random.randint(0, arr.size)
816+
arr[num] = arr[num] + "hello"
817+
elif rnd < 0.6666:
818+
np.add(arr, arr)
819+
else:
820+
arr[:] = random_string_list
821+
822+
with concurrent.futures.ThreadPoolExecutor(max_workers=8) as tpe:
823+
arr = np.array(random_string_list, dtype=dtype)
824+
futures = [tpe.submit(func, arr) for _ in range(100)]
825+
826+
for f in futures:
827+
f.result()

0 commit comments

Comments
 (0)