Skip to content

Commit 13cd98a

Browse files
committed
MAINT: Fixups and new tests based on Marten's reviews
More fixups are coming, the biggest change is that the error message is now improved when a reduction makes no sense such as `np.subtract.reduce(np.array([1, 2, 3], dtype="M8[s]"))` where input and output cannot have the same descriptor. (Some more fixups still to go)
1 parent 5842105 commit 13cd98a

File tree

4 files changed

+52
-27
lines changed

4 files changed

+52
-27
lines changed

numpy/core/src/umath/reduction.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ typedef int (PyArray_AssignReduceIdentityFunc)(PyArrayObject *result,
2020

2121
/*
2222
* Inner definition of the reduce loop, only used for a static function.
23-
* At some point around NmPy 1.6, there was probably an intention to make
23+
* At some point around NumPy 1.6, there was probably an intention to make
2424
* the reduce loop customizable at this level (per ufunc?).
2525
*
2626
* TODO: This should be refactored/removed.

numpy/core/src/umath/ufunc_object.c

Lines changed: 35 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -2485,9 +2485,9 @@ PyUFunc_GeneralizedFunctionInternal(PyUFuncObject *ufunc,
24852485

24862486
/* Final preparation of the arraymethod call */
24872487
PyArrayMethod_Context context = {
2488-
.caller = (PyObject *)ufunc,
2489-
.method = ufuncimpl,
2490-
.descriptors = operation_descrs,
2488+
.caller = (PyObject *)ufunc,
2489+
.method = ufuncimpl,
2490+
.descriptors = operation_descrs,
24912491
};
24922492
PyArrayMethod_StridedLoop *strided_loop;
24932493
NPY_ARRAYMETHOD_FLAGS flags = 0;
@@ -2607,9 +2607,9 @@ PyUFunc_GenericFunctionInternal(PyUFuncObject *ufunc,
26072607

26082608
/* Final preparation of the arraymethod call */
26092609
PyArrayMethod_Context context = {
2610-
.caller = (PyObject *)ufunc,
2611-
.method = ufuncimpl,
2612-
.descriptors = operation_descrs,
2610+
.caller = (PyObject *)ufunc,
2611+
.method = ufuncimpl,
2612+
.descriptors = operation_descrs,
26132613
};
26142614

26152615
/* Do the ufunc loop */
@@ -2679,13 +2679,17 @@ PyUFunc_GenericFunction(PyUFuncObject *NPY_UNUSED(ufunc),
26792679
* Promote and resolve a reduction like operation.
26802680
*
26812681
* @param ufunc
2682-
* @param arr The operation array (out was never used)
2682+
* @param arr The operation array
2683+
* @param out The output array or NULL if not provided. Note that NumPy always
2684+
* used out to mean the same as `dtype=out.dtype` and never passed
2685+
* the array itself to the type-resolution.
26832686
* @param signature The DType signature, which may already be set due to the
26842687
* dtype passed in by the user, or the special cases (add, multiply).
26852688
* (Contains strong references and may be modified.)
2686-
* @param enforce_uniform_args If `1` fully uniform dtypes/descriptors are
2687-
* enforced as required for accumulate and (currently) reduceat.
2689+
* @param enforce_uniform_args If `NPY_TRUE` fully uniform dtypes/descriptors
2690+
* are enforced as required for accumulate and (currently) reduceat.
26882691
* @param out_descrs New references to the resolved descriptors (on success).
2692+
* @param method The ufunc method, "reduce", "reduceat", or "accumulate".
26892693
26902694
* @returns ufuncimpl The `ArrayMethod` implemention to use. Or NULL if an
26912695
* error occurred.
@@ -2694,7 +2698,8 @@ static PyArrayMethodObject *
26942698
reducelike_promote_and_resolve(PyUFuncObject *ufunc,
26952699
PyArrayObject *arr, PyArrayObject *out,
26962700
PyArray_DTypeMeta *signature[3],
2697-
int enforce_uniform_args, PyArray_Descr *out_descrs[3])
2701+
npy_bool enforce_uniform_args, PyArray_Descr *out_descrs[3],
2702+
char *method)
26982703
{
26992704
/*
27002705
* Note that the `ops` is not realy correct. But legacy resolution
@@ -2731,22 +2736,28 @@ reducelike_promote_and_resolve(PyUFuncObject *ufunc,
27312736
return NULL;
27322737
}
27332738

2734-
/* Find the correct descriptors for the operation */
2739+
/*
2740+
* Find the correct descriptors for the operation. We use unsafe casting
2741+
* for historic reasons: The logic ufuncs required it to cast everything to
2742+
* boolean. However, we now special case the logical ufuncs, so that the
2743+
* casting safety could in principle be set to the default same-kind.
2744+
* (although this should possibly happen through a deprecation)
2745+
*/
27352746
if (resolve_descriptors(3, ufunc, ufuncimpl,
27362747
ops, out_descrs, signature, NPY_UNSAFE_CASTING) < 0) {
27372748
return NULL;
27382749
}
27392750

27402751
/*
2741-
* The first operand and output be the same array, so they should
2752+
* The first operand and output should be the same array, so they should
27422753
* be identical. The second argument can be different for reductions,
27432754
* but is checked to be identical for accumulate and reduceat.
27442755
*/
27452756
if (out_descrs[0] != out_descrs[2] || (
27462757
enforce_uniform_args && out_descrs[0] != out_descrs[1])) {
27472758
PyErr_Format(PyExc_TypeError,
2748-
"The resolved dtypes are not compatible with an accumulate "
2749-
"or reduceat loop.");
2759+
"the resolved dtypes are not compatible with %s.%s",
2760+
ufunc_get_name_cstr(ufunc), method);
27502761
goto fail;
27512762
}
27522763
/* TODO: This really should _not_ be unsafe casting (same above)! */
@@ -2912,7 +2923,7 @@ PyUFunc_Reduce(PyUFuncObject *ufunc,
29122923
}
29132924

29142925
/* Get the identity */
2915-
/* TODO: Both of these must be provided by the ArrayMethod! */
2926+
/* TODO: Both of these should be provided by the ArrayMethod! */
29162927
identity = _get_identity(ufunc, &reorderable);
29172928
if (identity == NULL) {
29182929
return NULL;
@@ -2938,7 +2949,7 @@ PyUFunc_Reduce(PyUFuncObject *ufunc,
29382949

29392950
PyArray_Descr *descrs[3];
29402951
PyArrayMethodObject *ufuncimpl = reducelike_promote_and_resolve(ufunc,
2941-
arr, out, signature, 0, descrs);
2952+
arr, out, signature, NPY_FALSE, descrs, "reduce");
29422953
if (ufuncimpl == NULL) {
29432954
Py_DECREF(initial);
29442955
return NULL;
@@ -3002,7 +3013,7 @@ PyUFunc_Accumulate(PyUFuncObject *ufunc, PyArrayObject *arr, PyArrayObject *out,
30023013

30033014
PyArray_Descr *descrs[3];
30043015
PyArrayMethodObject *ufuncimpl = reducelike_promote_and_resolve(ufunc,
3005-
arr, out, signature, 1, descrs);
3016+
arr, out, signature, NPY_TRUE, descrs, "accumulate");
30063017
if (ufuncimpl == NULL) {
30073018
return NULL;
30083019
}
@@ -3019,9 +3030,9 @@ PyUFunc_Accumulate(PyUFuncObject *ufunc, PyArrayObject *arr, PyArrayObject *out,
30193030
}
30203031

30213032
PyArrayMethod_Context context = {
3022-
.caller = (PyObject *)ufunc,
3023-
.method = ufuncimpl,
3024-
.descriptors = descrs,
3033+
.caller = (PyObject *)ufunc,
3034+
.method = ufuncimpl,
3035+
.descriptors = descrs,
30253036
};
30263037

30273038
ndim = PyArray_NDIM(arr);
@@ -3415,7 +3426,7 @@ PyUFunc_Reduceat(PyUFuncObject *ufunc, PyArrayObject *arr, PyArrayObject *ind,
34153426

34163427
PyArray_Descr *descrs[3];
34173428
PyArrayMethodObject *ufuncimpl = reducelike_promote_and_resolve(ufunc,
3418-
arr, out, signature, 1, descrs);
3429+
arr, out, signature, NPY_TRUE, descrs, "reduceat");
34193430
if (ufuncimpl == NULL) {
34203431
return NULL;
34213432
}
@@ -3432,9 +3443,9 @@ PyUFunc_Reduceat(PyUFuncObject *ufunc, PyArrayObject *arr, PyArrayObject *ind,
34323443
}
34333444

34343445
PyArrayMethod_Context context = {
3435-
.caller = (PyObject *)ufunc,
3436-
.method = ufuncimpl,
3437-
.descriptors = descrs,
3446+
.caller = (PyObject *)ufunc,
3447+
.method = ufuncimpl,
3448+
.descriptors = descrs,
34383449
};
34393450

34403451
ndim = PyArray_NDIM(arr);
@@ -4101,7 +4112,6 @@ PyUFunc_GenericReduction(PyUFuncObject *ufunc,
41014112
Py_XINCREF(signature[0]);
41024113
signature[2] = signature[0];
41034114

4104-
41054115
switch(operation) {
41064116
case UFUNC_REDUCE:
41074117
ret = PyUFunc_Reduce(ufunc,

numpy/core/tests/test_custom_dtypes.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,7 @@ def test_possible_and_impossible_reduce(self):
114114
# possible (the relaxed version of the old refusal to handle any
115115
# flexible dtype).
116116
with pytest.raises(TypeError,
117-
match="The resolved dtypes are not compatible"):
117+
match="the resolved dtypes are not compatible"):
118118
np.multiply.reduce(a)
119119

120120
def test_basic_multiply_promotion(self):

numpy/core/tests/test_datetime.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2029,6 +2029,21 @@ def test_datetime_maximum_reduce(self):
20292029
assert_equal(np.maximum.reduce(a),
20302030
np.timedelta64(7, 's'))
20312031

2032+
def test_datetime_no_subtract_reducelike(self):
2033+
# subtracting two datetime64 works, but we cannot reduce it, since
2034+
# the result of that subtraction will have a different dtype.
2035+
arr = np.array(["2021-12-02", "2019-05-12"], dtype="M8[ms]")
2036+
msg = r"the resolved dtypes are not compatible with subtract\."
2037+
2038+
with pytest.raises(TypeError, match=msg + "reduce"):
2039+
np.subtract.reduce(arr)
2040+
2041+
with pytest.raises(TypeError, match=msg + "accumulate"):
2042+
np.subtract.accumulate(arr)
2043+
2044+
with pytest.raises(TypeError, match=msg + "reduceat"):
2045+
np.subtract.reduceat(arr, [0])
2046+
20322047
def test_datetime_busday_offset(self):
20332048
# First Monday in June
20342049
assert_equal(

0 commit comments

Comments
 (0)