Skip to content

Commit ac214d2

Browse files
committed
MAINT,BUG: Fix ufunc.at to use new ufunc API
This also fixes a small issue that I forgot to include the special case for an unspecified output (or input): In this case matching is OK, so long the loop we pick can cast the operand. Previously, `ufunc.at` failed to check for floating point errors, this further adds the missing checks to match normal ufuncs.
1 parent 405c6ee commit ac214d2

File tree

5 files changed

+116
-39
lines changed

5 files changed

+116
-39
lines changed

numpy/core/src/multiarray/dtypemeta.h

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -74,9 +74,9 @@ typedef struct {
7474
#define NPY_DTYPE(descr) ((PyArray_DTypeMeta *)Py_TYPE(descr))
7575
#define NPY_DT_SLOTS(dtype) ((NPY_DType_Slots *)(dtype)->dt_slots)
7676

77-
#define NPY_DT_is_legacy(dtype) ((dtype)->flags & NPY_DT_LEGACY)
78-
#define NPY_DT_is_abstract(dtype) ((dtype)->flags & NPY_DT_ABSTRACT)
79-
#define NPY_DT_is_parametric(dtype) ((dtype)->flags & NPY_DT_PARAMETRIC)
77+
#define NPY_DT_is_legacy(dtype) (((dtype)->flags & NPY_DT_LEGACY) != 0)
78+
#define NPY_DT_is_abstract(dtype) (((dtype)->flags & NPY_DT_ABSTRACT) != 0)
79+
#define NPY_DT_is_parametric(dtype) (((dtype)->flags & NPY_DT_PARAMETRIC) != 0)
8080

8181
/*
8282
* Macros for convenient classmethod calls, since these require

numpy/core/src/umath/dispatching.c

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -193,6 +193,10 @@ resolve_implementation_info(PyUFuncObject *ufunc,
193193
/* Unspecified out always matches (see below for inputs) */
194194
continue;
195195
}
196+
if (resolver_dtype == (PyArray_DTypeMeta *)Py_None) {
197+
/* always matches */
198+
continue;
199+
}
196200
if (given_dtype == resolver_dtype) {
197201
continue;
198202
}

numpy/core/src/umath/ufunc_object.c

Lines changed: 86 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -5865,15 +5865,13 @@ ufunc_at(PyUFuncObject *ufunc, PyObject *args)
58655865
PyArrayObject *op2_array = NULL;
58665866
PyArrayMapIterObject *iter = NULL;
58675867
PyArrayIterObject *iter2 = NULL;
5868-
PyArray_Descr *dtypes[3] = {NULL, NULL, NULL};
58695868
PyArrayObject *operands[3] = {NULL, NULL, NULL};
58705869
PyArrayObject *array_operands[3] = {NULL, NULL, NULL};
58715870

5872-
int needs_api = 0;
5871+
PyArray_DTypeMeta *signature[3] = {NULL, NULL, NULL};
5872+
PyArray_DTypeMeta *operand_DTypes[3] = {NULL, NULL, NULL};
5873+
PyArray_Descr *operation_descrs[3] = {NULL, NULL, NULL};
58735874

5874-
PyUFuncGenericFunction innerloop;
5875-
void *innerloopdata;
5876-
npy_intp i;
58775875
int nop;
58785876

58795877
/* override vars */
@@ -5886,6 +5884,10 @@ ufunc_at(PyUFuncObject *ufunc, PyObject *args)
58865884
int buffersize;
58875885
int errormask = 0;
58885886
char * err_msg = NULL;
5887+
5888+
PyArrayMethod_StridedLoop *strided_loop;
5889+
NpyAuxData *auxdata = NULL;
5890+
58895891
NPY_BEGIN_THREADS_DEF;
58905892

58915893
if (ufunc->nin > 2) {
@@ -5973,26 +5975,51 @@ ufunc_at(PyUFuncObject *ufunc, PyObject *args)
59735975

59745976
/*
59755977
* Create dtypes array for either one or two input operands.
5976-
* The output operand is set to the first input operand
5978+
* Compare to the logic in `convert_ufunc_arguments`.
5979+
* TODO: It may be good to review some of this behaviour, since the
5980+
* operand array is special (it is written to) similar to reductions.
5981+
* Using unsafe-casting as done here, is likely not desirable.
59775982
*/
59785983
operands[0] = op1_array;
5984+
operand_DTypes[0] = NPY_DTYPE(PyArray_DESCR(op1_array));
5985+
Py_INCREF(operand_DTypes[0]);
5986+
int force_legacy_promotion = 0;
5987+
int allow_legacy_promotion = NPY_DT_is_legacy(operand_DTypes[0]);
5988+
59795989
if (op2_array != NULL) {
59805990
operands[1] = op2_array;
5981-
operands[2] = op1_array;
5991+
operand_DTypes[1] = NPY_DTYPE(PyArray_DESCR(op2_array));
5992+
Py_INCREF(operand_DTypes[1]);
5993+
allow_legacy_promotion &= NPY_DT_is_legacy(operand_DTypes[1]);
5994+
operands[2] = operands[0];
5995+
operand_DTypes[2] = operand_DTypes[0];
5996+
Py_INCREF(operand_DTypes[2]);
5997+
59825998
nop = 3;
5999+
if (allow_legacy_promotion && ((PyArray_NDIM(op1_array) == 0)
6000+
!= (PyArray_NDIM(op2_array) == 0))) {
6001+
/* both are legacy and only one is 0-D: force legacy */
6002+
force_legacy_promotion = should_use_min_scalar(2, operands, 0, NULL);
6003+
}
59836004
}
59846005
else {
5985-
operands[1] = op1_array;
6006+
operands[1] = operands[0];
6007+
operand_DTypes[1] = operand_DTypes[0];
6008+
Py_INCREF(operand_DTypes[1]);
59866009
operands[2] = NULL;
59876010
nop = 2;
59886011
}
59896012

5990-
if (ufunc->type_resolver(ufunc, NPY_UNSAFE_CASTING,
5991-
operands, NULL, dtypes) < 0) {
6013+
PyArrayMethodObject *ufuncimpl = promote_and_get_ufuncimpl(ufunc,
6014+
operands, signature, operand_DTypes,
6015+
force_legacy_promotion, allow_legacy_promotion);
6016+
if (ufuncimpl == NULL) {
59926017
goto fail;
59936018
}
5994-
if (ufunc->legacy_inner_loop_selector(ufunc, dtypes,
5995-
&innerloop, &innerloopdata, &needs_api) < 0) {
6019+
6020+
/* Find the correct descriptors for the operation */
6021+
if (resolve_descriptors(nop, ufunc, ufuncimpl,
6022+
operands, operation_descrs, signature, NPY_UNSAFE_CASTING) < 0) {
59966023
goto fail;
59976024
}
59986025

@@ -6053,21 +6080,44 @@ ufunc_at(PyUFuncObject *ufunc, PyObject *args)
60536080
NPY_ITER_GROWINNER|
60546081
NPY_ITER_DELAY_BUFALLOC,
60556082
NPY_KEEPORDER, NPY_UNSAFE_CASTING,
6056-
op_flags, dtypes,
6083+
op_flags, operation_descrs,
60576084
-1, NULL, NULL, buffersize);
60586085

60596086
if (iter_buffer == NULL) {
60606087
goto fail;
60616088
}
60626089

6063-
needs_api = needs_api | NpyIter_IterationNeedsAPI(iter_buffer);
6064-
60656090
iternext = NpyIter_GetIterNext(iter_buffer, NULL);
60666091
if (iternext == NULL) {
60676092
NpyIter_Deallocate(iter_buffer);
60686093
goto fail;
60696094
}
60706095

6096+
PyArrayMethod_Context context = {
6097+
.caller = (PyObject *)ufunc,
6098+
.method = ufuncimpl,
6099+
.descriptors = operation_descrs,
6100+
};
6101+
6102+
NPY_ARRAYMETHOD_FLAGS flags;
6103+
/* Use contiguous strides; if there is such a loop it may be faster */
6104+
npy_intp strides[3] = {
6105+
operation_descrs[0]->elsize, operation_descrs[1]->elsize, 0};
6106+
if (nop == 3) {
6107+
strides[2] = operation_descrs[2]->elsize;
6108+
}
6109+
6110+
if (ufuncimpl->get_strided_loop(&context, 1, 0, strides,
6111+
&strided_loop, &auxdata, &flags) < 0) {
6112+
goto fail;
6113+
}
6114+
int needs_api = (flags & NPY_METH_REQUIRES_PYAPI) != 0;
6115+
needs_api |= NpyIter_IterationNeedsAPI(iter_buffer);
6116+
if (!(flags & NPY_METH_NO_FLOATINGPOINT_ERRORS)) {
6117+
/* Start with the floating-point exception flags cleared */
6118+
npy_clear_floatstatus_barrier((char*)&iter);
6119+
}
6120+
60716121
if (!needs_api) {
60726122
NPY_BEGIN_THREADS;
60736123
}
@@ -6076,14 +6126,13 @@ ufunc_at(PyUFuncObject *ufunc, PyObject *args)
60766126
* Iterate over first and second operands and call ufunc
60776127
* for each pair of inputs
60786128
*/
6079-
i = iter->size;
6080-
while (i > 0)
6129+
int res = 0;
6130+
for (npy_intp i = iter->size; i > 0; i--)
60816131
{
60826132
char *dataptr[3];
60836133
char **buffer_dataptr;
60846134
/* one element at a time, no stride required but read by innerloop */
6085-
npy_intp count[3] = {1, 0xDEADBEEF, 0xDEADBEEF};
6086-
npy_intp stride[3] = {0xDEADBEEF, 0xDEADBEEF, 0xDEADBEEF};
6135+
npy_intp count = 1;
60876136

60886137
/*
60896138
* Set up data pointers for either one or two input operands.
@@ -6102,14 +6151,14 @@ ufunc_at(PyUFuncObject *ufunc, PyObject *args)
61026151
/* Reset NpyIter data pointers which will trigger a buffer copy */
61036152
NpyIter_ResetBasePointers(iter_buffer, dataptr, &err_msg);
61046153
if (err_msg) {
6154+
res = -1;
61056155
break;
61066156
}
61076157

61086158
buffer_dataptr = NpyIter_GetDataPtrArray(iter_buffer);
61096159

6110-
innerloop(buffer_dataptr, count, stride, innerloopdata);
6111-
6112-
if (needs_api && PyErr_Occurred()) {
6160+
res = strided_loop(&context, buffer_dataptr, &count, strides, auxdata);
6161+
if (res != 0) {
61136162
break;
61146163
}
61156164

@@ -6123,32 +6172,35 @@ ufunc_at(PyUFuncObject *ufunc, PyObject *args)
61236172
if (iter2 != NULL) {
61246173
PyArray_ITER_NEXT(iter2);
61256174
}
6126-
6127-
i--;
61286175
}
61296176

61306177
NPY_END_THREADS;
61316178

6132-
if (err_msg) {
6179+
if (res != 0 && err_msg) {
61336180
PyErr_SetString(PyExc_ValueError, err_msg);
61346181
}
6182+
if (res == 0 && !(flags & NPY_METH_NO_FLOATINGPOINT_ERRORS)) {
6183+
/* NOTE: We could check float errors even when `res < 0` */
6184+
res = _check_ufunc_fperr(errormask, NULL, "at");
6185+
}
61356186

6187+
NPY_AUXDATA_FREE(auxdata);
61366188
NpyIter_Deallocate(iter_buffer);
61376189

61386190
Py_XDECREF(op2_array);
61396191
Py_XDECREF(iter);
61406192
Py_XDECREF(iter2);
6141-
for (i = 0; i < 3; i++) {
6142-
Py_XDECREF(dtypes[i]);
6193+
for (int i = 0; i < 3; i++) {
6194+
Py_XDECREF(operation_descrs[i]);
61436195
Py_XDECREF(array_operands[i]);
61446196
}
61456197

61466198
/*
6147-
* An error should only be possible if needs_api is true, but this is not
6148-
* strictly correct for old-style ufuncs (e.g. `power` released the GIL
6149-
* but manually set an Exception).
6199+
* An error should only be possible if needs_api is true or `res != 0`,
6200+
* but this is not strictly correct for old-style ufuncs
6201+
* (e.g. `power` released the GIL but manually set an Exception).
61506202
*/
6151-
if (PyErr_Occurred()) {
6203+
if (res != 0 || PyErr_Occurred()) {
61526204
return NULL;
61536205
}
61546206
else {
@@ -6163,10 +6215,11 @@ ufunc_at(PyUFuncObject *ufunc, PyObject *args)
61636215
Py_XDECREF(op2_array);
61646216
Py_XDECREF(iter);
61656217
Py_XDECREF(iter2);
6166-
for (i = 0; i < 3; i++) {
6167-
Py_XDECREF(dtypes[i]);
6218+
for (int i = 0; i < 3; i++) {
6219+
Py_XDECREF(operation_descrs[i]);
61686220
Py_XDECREF(array_operands[i]);
61696221
}
6222+
NPY_AUXDATA_FREE(auxdata);
61706223

61716224
return NULL;
61726225
}

numpy/core/tests/test_custom_dtypes.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -117,18 +117,36 @@ def test_possible_and_impossible_reduce(self):
117117
match="the resolved dtypes are not compatible"):
118118
np.multiply.reduce(a)
119119

120+
def test_basic_ufunc_at(self):
121+
float_a = np.array([1., 2., 3.])
122+
b = self._get_array(2.)
123+
124+
float_b = b.view(np.float64).copy()
125+
np.multiply.at(float_b, [1, 1, 1], float_a)
126+
np.multiply.at(b, [1, 1, 1], float_a)
127+
128+
assert_array_equal(b.view(np.float64), float_b)
129+
120130
def test_basic_multiply_promotion(self):
121131
float_a = np.array([1., 2., 3.])
122132
b = self._get_array(2.)
123133

124134
res1 = float_a * b
125135
res2 = b * float_a
136+
126137
# one factor is one, so we get the factor of b:
127138
assert res1.dtype == res2.dtype == b.dtype
128139
expected_view = float_a * b.view(np.float64)
129140
assert_array_equal(res1.view(np.float64), expected_view)
130141
assert_array_equal(res2.view(np.float64), expected_view)
131142

143+
# Check that promotion works when `out` is used:
144+
np.multiply(b, float_a, out=res2)
145+
with pytest.raises(TypeError):
146+
# The promoter accepts this (maybe it should not), but the SFloat
147+
# result cannot be cast to integer:
148+
np.multiply(b, float_a, out=np.arange(3))
149+
132150
def test_basic_addition(self):
133151
a = self._get_array(2.)
134152
b = self._get_array(4.)

numpy/core/tests/test_ufunc.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2396,14 +2396,16 @@ def test_reduce_casterrors(offset):
23962396

23972397
@pytest.mark.parametrize("method",
23982398
[np.add.accumulate, np.add.reduce,
2399-
pytest.param(lambda x: np.add.reduceat(x, [0]), id="reduceat")])
2400-
def test_reducelike_floaterrors(method):
2401-
# adding inf and -inf creates an invalid float and should give a warning
2399+
pytest.param(lambda x: np.add.reduceat(x, [0]), id="reduceat"),
2400+
pytest.param(lambda x: np.log.at(x, [2]), id="at")])
2401+
def test_ufunc_methods_floaterrors(method):
2402+
# adding inf and -inf (or log(-inf) creates an invalid float and warns
24022403
arr = np.array([np.inf, 0, -np.inf])
24032404
with np.errstate(all="warn"):
24042405
with pytest.warns(RuntimeWarning, match="invalid value"):
24052406
method(arr)
24062407

2408+
arr = np.array([np.inf, 0, -np.inf])
24072409
with np.errstate(all="raise"):
24082410
with pytest.raises(FloatingPointError):
24092411
method(arr)

0 commit comments

Comments
 (0)