Skip to content

Commit fb91abc

Browse files
seberg and ngoldbaum authored
BUG: Fix repeat, accumulate for strings and accumulate API logic (numpy#27773)
* BUG: Fix repeat, accumulate for strings and accumulate API logic

  This fixes three (actually four) things:

  1. `needs_api` was wrongly overwritten; this was just a line not deleted earlier.
  2. Accumulate didn't allow string dtypes because it only allowed objects with references; this fixes it and adds a test.
  3. `repeat` was just broken with string dtype... I guess there was no actual test.
  4. `repeat` internals passed on `cast_info` not by reference. I guess that isn't a bug, but it's weird.

  Tests cover things relatively well, although things like the GIL release being right are of course harder to test.

  Closes gh-27709

* DOC: Add small code comment about helping compiler

* Update numpy/_core/src/umath/ufunc_object.c

  Co-authored-by: Nathan Goldbaum <[email protected]>

---------

Co-authored-by: Nathan Goldbaum <[email protected]>
1 parent 9242c21 commit fb91abc

File tree

3 files changed

+96 -50 lines changed

numpy/_core/src/multiarray/item_selection.c

Lines changed: 23 additions & 18 deletions
Original file line number | Diff line number | Diff line change
@@ -785,21 +785,21 @@ static NPY_GCC_OPT_3 inline int
785785
npy_fastrepeat_impl(
786786
npy_intp n_outer, npy_intp n, npy_intp nel, npy_intp chunk,
787787
npy_bool broadcast, npy_intp* counts, char* new_data, char* old_data,
788-
npy_intp elsize, NPY_cast_info cast_info, int needs_refcounting)
788+
npy_intp elsize, NPY_cast_info *cast_info, int needs_custom_copy)
789789
{
790790
npy_intp i, j, k;
791791
for (i = 0; i < n_outer; i++) {
792792
for (j = 0; j < n; j++) {
793793
npy_intp tmp = broadcast ? counts[0] : counts[j];
794794
for (k = 0; k < tmp; k++) {
795-
if (!needs_refcounting) {
795+
if (!needs_custom_copy) {
796796
memcpy(new_data, old_data, chunk);
797797
}
798798
else {
799799
char *data[2] = {old_data, new_data};
800800
npy_intp strides[2] = {elsize, elsize};
801-
if (cast_info.func(&cast_info.context, data, &nel,
802-
strides, cast_info.auxdata) < 0) {
801+
if (cast_info->func(&cast_info->context, data, &nel,
802+
strides, cast_info->auxdata) < 0) {
803803
return -1;
804804
}
805805
}
@@ -811,48 +811,53 @@ npy_fastrepeat_impl(
811811
return 0;
812812
}
813813

814+
815+
/*
816+
* Helper to allow the compiler to specialize for all direct element copy
817+
* cases (e.g. all numerical dtypes).
818+
*/
814819
static NPY_GCC_OPT_3 int
815820
npy_fastrepeat(
816821
npy_intp n_outer, npy_intp n, npy_intp nel, npy_intp chunk,
817822
npy_bool broadcast, npy_intp* counts, char* new_data, char* old_data,
818-
npy_intp elsize, NPY_cast_info cast_info, int needs_refcounting)
823+
npy_intp elsize, NPY_cast_info *cast_info, int needs_custom_copy)
819824
{
820-
if (!needs_refcounting) {
825+
if (!needs_custom_copy) {
821826
if (chunk == 1) {
822827
return npy_fastrepeat_impl(
823828
n_outer, n, nel, chunk, broadcast, counts, new_data, old_data,
824-
elsize, cast_info, needs_refcounting);
829+
elsize, cast_info, needs_custom_copy);
825830
}
826831
if (chunk == 2) {
827832
return npy_fastrepeat_impl(
828833
n_outer, n, nel, chunk, broadcast, counts, new_data, old_data,
829-
elsize, cast_info, needs_refcounting);
834+
elsize, cast_info, needs_custom_copy);
830835
}
831836
if (chunk == 4) {
832837
return npy_fastrepeat_impl(
833838
n_outer, n, nel, chunk, broadcast, counts, new_data, old_data,
834-
elsize, cast_info, needs_refcounting);
839+
elsize, cast_info, needs_custom_copy);
835840
}
836841
if (chunk == 8) {
837842
return npy_fastrepeat_impl(
838843
n_outer, n, nel, chunk, broadcast, counts, new_data, old_data,
839-
elsize, cast_info, needs_refcounting);
844+
elsize, cast_info, needs_custom_copy);
840845
}
841846
if (chunk == 16) {
842847
return npy_fastrepeat_impl(
843848
n_outer, n, nel, chunk, broadcast, counts, new_data, old_data,
844-
elsize, cast_info, needs_refcounting);
849+
elsize, cast_info, needs_custom_copy);
845850
}
846851
if (chunk == 32) {
847852
return npy_fastrepeat_impl(
848853
n_outer, n, nel, chunk, broadcast, counts, new_data, old_data,
849-
elsize, cast_info, needs_refcounting);
854+
elsize, cast_info, needs_custom_copy);
850855
}
851856
}
852857

853858
return npy_fastrepeat_impl(
854859
n_outer, n, nel, chunk, broadcast, counts, new_data, old_data, elsize,
855-
cast_info, needs_refcounting);
860+
cast_info, needs_custom_copy);
856861
}
857862

858863

@@ -872,7 +877,6 @@ PyArray_Repeat(PyArrayObject *aop, PyObject *op, int axis)
872877
char *new_data, *old_data;
873878
NPY_cast_info cast_info;
874879
NPY_ARRAYMETHOD_FLAGS flags;
875-
int needs_refcounting;
876880

877881
repeats = (PyArrayObject *)PyArray_ContiguousFromAny(op, NPY_INTP, 0, 1);
878882
if (repeats == NULL) {
@@ -897,7 +901,6 @@ PyArray_Repeat(PyArrayObject *aop, PyObject *op, int axis)
897901
aop = (PyArrayObject *)ap;
898902
n = PyArray_DIM(aop, axis);
899903
NPY_cast_info_init(&cast_info);
900-
needs_refcounting = PyDataType_REFCHK(PyArray_DESCR(aop));
901904

902905
if (!broadcast && PyArray_SIZE(repeats) != n) {
903906
PyErr_Format(PyExc_ValueError,
@@ -947,16 +950,18 @@ PyArray_Repeat(PyArrayObject *aop, PyObject *op, int axis)
947950
n_outer *= PyArray_DIMS(aop)[i];
948951
}
949952

950-
if (needs_refcounting) {
953+
int needs_custom_copy = 0;
954+
if (PyDataType_REFCHK(PyArray_DESCR(ret))) {
955+
needs_custom_copy = 1;
951956
if (PyArray_GetDTypeTransferFunction(
952-
1, elsize, elsize, PyArray_DESCR(aop), PyArray_DESCR(aop), 0,
957+
1, elsize, elsize, PyArray_DESCR(aop), PyArray_DESCR(ret), 0,
953958
&cast_info, &flags) < 0) {
954959
goto fail;
955960
}
956961
}
957962

958963
if (npy_fastrepeat(n_outer, n, nel, chunk, broadcast, counts, new_data,
959-
old_data, elsize, cast_info, needs_refcounting) < 0) {
964+
old_data, elsize, &cast_info, needs_custom_copy) < 0) {
960965
goto fail;
961966
}
962967

numpy/_core/src/umath/ufunc_object.c

Lines changed: 39 additions & 32 deletions
Original file line number | Diff line number | Diff line change
@@ -2593,6 +2593,10 @@ PyUFunc_Accumulate(PyUFuncObject *ufunc, PyArrayObject *arr, PyArrayObject *out,
25932593
int idim, ndim;
25942594
int needs_api, need_outer_iterator;
25952595
int res = 0;
2596+
2597+
NPY_cast_info copy_info;
2598+
NPY_cast_info_init(&copy_info);
2599+
25962600
#if NPY_UF_DBG_TRACING
25972601
const char *ufunc_name = ufunc_get_name_cstr(ufunc);
25982602
#endif
@@ -2637,14 +2641,6 @@ PyUFunc_Accumulate(PyUFuncObject *ufunc, PyArrayObject *arr, PyArrayObject *out,
26372641
assert(PyArray_EquivTypes(descrs[0], descrs[1])
26382642
&& PyArray_EquivTypes(descrs[0], descrs[2]));
26392643

2640-
if (PyDataType_REFCHK(descrs[2]) && descrs[2]->type_num != NPY_OBJECT) {
2641-
/* This can be removed, but the initial element copy needs fixing */
2642-
PyErr_SetString(PyExc_TypeError,
2643-
"accumulation currently only supports `object` dtype with "
2644-
"references");
2645-
goto fail;
2646-
}
2647-
26482644
PyArrayMethod_Context context = {
26492645
.caller = (PyObject *)ufunc,
26502646
.method = ufuncimpl,
@@ -2740,10 +2736,10 @@ PyUFunc_Accumulate(PyUFuncObject *ufunc, PyArrayObject *arr, PyArrayObject *out,
27402736
else {
27412737
PyArray_Descr *dtype = descrs[0];
27422738
Py_INCREF(dtype);
2743-
op[0] = out = (PyArrayObject *)PyArray_NewFromDescr(
2739+
op[0] = out = (PyArrayObject *)PyArray_NewFromDescr_int(
27442740
&PyArray_Type, dtype,
27452741
ndim, PyArray_DIMS(op[1]), NULL, NULL,
2746-
0, NULL);
2742+
0, NULL, NULL, _NPY_ARRAY_ENSURE_DTYPE_IDENTITY);
27472743
if (out == NULL) {
27482744
goto fail;
27492745
}
@@ -2766,6 +2762,18 @@ PyUFunc_Accumulate(PyUFuncObject *ufunc, PyArrayObject *arr, PyArrayObject *out,
27662762
1, 0, fixed_strides, &strided_loop, &auxdata, &flags) < 0) {
27672763
goto fail;
27682764
}
2765+
/* Set up function to copy the first element if it has references */
2766+
if (PyDataType_REFCHK(descrs[2])) {
2767+
NPY_ARRAYMETHOD_FLAGS copy_flags;
2768+
/* Setup guarantees aligned here. */
2769+
if (PyArray_GetDTypeTransferFunction(
2770+
1, 0, 0, descrs[1], descrs[2], 0, &copy_info,
2771+
&copy_flags) == NPY_FAIL) {
2772+
goto fail;
2773+
}
2774+
flags = PyArrayMethod_COMBINED_FLAGS(flags, copy_flags);
2775+
}
2776+
27692777
needs_api = (flags & NPY_METH_REQUIRES_PYAPI) != 0;
27702778
if (!(flags & NPY_METH_NO_FLOATINGPOINT_ERRORS)) {
27712779
/* Start with the floating-point exception flags cleared */
@@ -2829,18 +2837,17 @@ PyUFunc_Accumulate(PyUFuncObject *ufunc, PyArrayObject *arr, PyArrayObject *out,
28292837
* Output (dataptr[0]) and input (dataptr[1]) may point to
28302838
* the same memory, e.g. np.add.accumulate(a, out=a).
28312839
*/
2832-
if (descrs[2]->type_num == NPY_OBJECT) {
2833-
/*
2834-
* Incref before decref to avoid the possibility of the
2835-
* reference count being zero temporarily.
2836-
*/
2837-
Py_XINCREF(*(PyObject **)dataptr_copy[1]);
2838-
Py_XDECREF(*(PyObject **)dataptr_copy[0]);
2839-
*(PyObject **)dataptr_copy[0] =
2840-
*(PyObject **)dataptr_copy[1];
2840+
if (copy_info.func) {
2841+
const npy_intp one = 1;
2842+
if (copy_info.func(
2843+
&copy_info.context, &dataptr_copy[1], &one,
2844+
&stride_copy[1], copy_info.auxdata) < 0) {
2845+
NPY_END_THREADS;
2846+
goto fail;
2847+
}
28412848
}
28422849
else {
2843-
memmove(dataptr_copy[0], dataptr_copy[1], itemsize);
2850+
memmove(dataptr_copy[2], dataptr_copy[1], itemsize);
28442851
}
28452852

28462853
if (count_m1 > 0) {
@@ -2889,18 +2896,17 @@ PyUFunc_Accumulate(PyUFuncObject *ufunc, PyArrayObject *arr, PyArrayObject *out,
28892896
* Output (dataptr[0]) and input (dataptr[1]) may point to the
28902897
* same memory, e.g. np.add.accumulate(a, out=a).
28912898
*/
2892-
if (descrs[2]->type_num == NPY_OBJECT) {
2893-
/*
2894-
* Incref before decref to avoid the possibility of the
2895-
* reference count being zero temporarily.
2896-
*/
2897-
Py_XINCREF(*(PyObject **)dataptr_copy[1]);
2898-
Py_XDECREF(*(PyObject **)dataptr_copy[0]);
2899-
*(PyObject **)dataptr_copy[0] =
2900-
*(PyObject **)dataptr_copy[1];
2899+
if (copy_info.func) {
2900+
const npy_intp one = 1;
2901+
const npy_intp strides[2] = {itemsize, itemsize};
2902+
if (copy_info.func(
2903+
&copy_info.context, &dataptr_copy[1], &one,
2904+
strides, copy_info.auxdata) < 0) {
2905+
goto fail;
2906+
}
29012907
}
29022908
else {
2903-
memmove(dataptr_copy[0], dataptr_copy[1], itemsize);
2909+
memmove(dataptr_copy[2], dataptr_copy[1], itemsize);
29042910
}
29052911

29062912
if (count > 1) {
@@ -2910,8 +2916,6 @@ PyUFunc_Accumulate(PyUFuncObject *ufunc, PyArrayObject *arr, PyArrayObject *out,
29102916

29112917
NPY_UF_DBG_PRINT1("iterator loop count %d\n", (int)count);
29122918

2913-
needs_api = PyDataType_REFCHK(descrs[0]);
2914-
29152919
if (!needs_api) {
29162920
NPY_BEGIN_THREADS_THRESHOLDED(count);
29172921
}
@@ -2925,6 +2929,7 @@ PyUFunc_Accumulate(PyUFuncObject *ufunc, PyArrayObject *arr, PyArrayObject *out,
29252929

29262930
finish:
29272931
NPY_AUXDATA_FREE(auxdata);
2932+
NPY_cast_info_xfree(&copy_info);
29282933
Py_DECREF(descrs[0]);
29292934
Py_DECREF(descrs[1]);
29302935
Py_DECREF(descrs[2]);
@@ -2949,6 +2954,8 @@ PyUFunc_Accumulate(PyUFuncObject *ufunc, PyArrayObject *arr, PyArrayObject *out,
29492954
Py_XDECREF(out);
29502955

29512956
NPY_AUXDATA_FREE(auxdata);
2957+
NPY_cast_info_xfree(&copy_info);
2958+
29522959
Py_XDECREF(descrs[0]);
29532960
Py_XDECREF(descrs[1]);
29542961
Py_XDECREF(descrs[2]);

numpy/_core/tests/test_stringdtype.py

Lines changed: 34 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1574,6 +1574,40 @@ def test_unset_na_coercion():
15741574
arr == op
15751575

15761576

1577+
def test_repeat(string_array):
1578+
res = string_array.repeat(1000)
1579+
# Create an empty array with expanded dimension, and fill it. Then,
1580+
# reshape it to the expected result.
1581+
expected = np.empty_like(string_array, shape=string_array.shape + (1000,))
1582+
expected[...] = string_array[:, np.newaxis]
1583+
expected = expected.reshape(-1)
1584+
1585+
assert_array_equal(res, expected, strict=True)
1586+
1587+
1588+
@pytest.mark.parametrize("tile", [1, 6, (2, 5)])
1589+
def test_accumulation(string_array, tile):
1590+
"""Accumulation is odd for StringDType but tests dtypes with references.
1591+
"""
1592+
# Fill with mostly empty strings to not create absurdly big strings
1593+
arr = np.zeros_like(string_array, shape=(100,))
1594+
arr[:len(string_array)] = string_array
1595+
arr[-len(string_array):] = string_array
1596+
1597+
# Bloat size a bit (get above thresholds and test >1 ndim).
1598+
arr = np.tile(string_array, tile)
1599+
1600+
res = np.add.accumulate(arr, axis=0)
1601+
res_obj = np.add.accumulate(arr.astype(object), axis=0)
1602+
assert_array_equal(res, res_obj.astype(arr.dtype), strict=True)
1603+
1604+
if arr.ndim > 1:
1605+
res = np.add.accumulate(arr, axis=-1)
1606+
res_obj = np.add.accumulate(arr.astype(object), axis=-1)
1607+
1608+
assert_array_equal(res, res_obj.astype(arr.dtype), strict=True)
1609+
1610+
15771611
class TestImplementation:
15781612
"""Check that strings are stored in the arena when possible.
15791613

0 commit comments

Comments
 (0)