Skip to content

Commit 27f3b03

Browse files
committed
TST: Add a test covering logical ufuncs for custom DTypes
In particular, this covers a casting error that currently cannot be hit for normal ufuncs, because they already check casting during the legacy dtype resolution (which is called first).
1 parent 1242ba3 commit 27f3b03

File tree

2 files changed

+72
-0
lines changed

2 files changed

+72
-0
lines changed

numpy/core/src/umath/_scaled_float_dtype.c

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -398,6 +398,42 @@ float_to_from_sfloat_resolve_descriptors(
398398
}
399399

400400

401+
/*
402+
* Cast to boolean (for testing the logical functions a bit better).
403+
*/
404+
static int
405+
cast_sfloat_to_bool(PyArrayMethod_Context *NPY_UNUSED(context),
406+
char *const data[], npy_intp const dimensions[],
407+
npy_intp const strides[], NpyAuxData *NPY_UNUSED(auxdata))
408+
{
409+
npy_intp N = dimensions[0];
410+
char *in = data[0];
411+
char *out = data[1];
412+
for (npy_intp i = 0; i < N; i++) {
413+
*(npy_bool *)out = *(double *)in != 0;
414+
in += strides[0];
415+
out += strides[1];
416+
}
417+
return 0;
418+
}
419+
420+
static NPY_CASTING
421+
sfloat_to_bool_resolve_descriptors(
422+
PyArrayMethodObject *NPY_UNUSED(self),
423+
PyArray_DTypeMeta *NPY_UNUSED(dtypes[2]),
424+
PyArray_Descr *given_descrs[2],
425+
PyArray_Descr *loop_descrs[2])
426+
{
427+
Py_INCREF(given_descrs[0]);
428+
loop_descrs[0] = given_descrs[0];
429+
if (loop_descrs[0] == NULL) {
430+
return -1;
431+
}
432+
loop_descrs[1] = PyArray_DescrFromType(NPY_BOOL); /* cannot fail */
433+
return NPY_UNSAFE_CASTING;
434+
}
435+
436+
401437
static int
402438
init_casts(void)
403439
{
@@ -453,6 +489,22 @@ init_casts(void)
453489
return -1;
454490
}
455491

492+
slots[0].slot = NPY_METH_resolve_descriptors;
493+
slots[0].pfunc = &sfloat_to_bool_resolve_descriptors;
494+
slots[1].slot = NPY_METH_strided_loop;
495+
slots[1].pfunc = &cast_sfloat_to_bool;
496+
slots[2].slot = 0;
497+
slots[2].pfunc = NULL;
498+
499+
spec.name = "sfloat_to_bool_cast";
500+
dtypes[0] = &PyArray_SFloatDType;
501+
dtypes[1] = PyArray_DTypeFromTypeNum(NPY_BOOL);
502+
Py_DECREF(dtypes[1]); /* immortal anyway */
503+
504+
if (PyArray_AddCastingImplementation_FromSpec(&spec, 0)) {
505+
return -1;
506+
}
507+
456508
return 0;
457509
}
458510

numpy/core/tests/test_custom_dtypes.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -161,3 +161,23 @@ def test_addition_cast_safety(self):
161161
# Check that casting the output fails also (done by the ufunc here)
162162
with pytest.raises(TypeError):
163163
np.add(a, a, out=c, casting="safe")
164+
165+
@pytest.mark.parametrize("ufunc",
166+
[np.logical_and, np.logical_or, np.logical_xor])
167+
def test_logical_ufuncs_casts_to_bool(self, ufunc):
168+
a = self._get_array(2.)
169+
a[0] = 0. # make sure first element is considered False.
170+
171+
float_equiv = a.astype(float)
172+
expected = ufunc(float_equiv, float_equiv)
173+
res = ufunc(a, a)
174+
assert_array_equal(res, expected)
175+
176+
# also check that the same works for reductions:
177+
expected = ufunc.reduce(float_equiv)
178+
res = ufunc.reduce(a)
179+
assert_array_equal(res, expected)
180+
181+
# The output casting does not match the bool, bool -> bool loop:
182+
with pytest.raises(TypeError):
183+
ufunc(a, a, out=np.empty(a.shape, dtype=int), casting="equiv")

0 commit comments

Comments
 (0)