Skip to content

Commit b9d07ab

Browse files
authored
Merge pull request numpy#27091 from seberg/copyto-safety
API,BUG: Fix copyto (and ufunc) handling of scalar cast safety
2 parents 41cc67a + 42b58e3 commit b9d07ab

File tree

9 files changed

+344
-87
lines changed

9 files changed

+344
-87
lines changed
Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
Cast-safety fixes in ``copyto`` and ``full``
2+
--------------------------------------------
3+
``copyto`` now uses NEP 50 correctly and applies this to its cast safety.
4+
Python integer to NumPy integer casts and Python float to NumPy float casts
5+
are now considered "safe" even if assignment may fail or precision may be lost.
6+
This means the following examples change slightly:
7+
8+
* ``np.copyto(int8_arr, 1000)`` previously performed an unsafe/same-kind cast
9+
of the Python integer. It will now always raise; to achieve an unsafe cast
10+
you must pass an array or NumPy scalar.
11+
* ``np.copyto(uint8_arr, 1000, casting="safe")`` will raise an OverflowError
12+
rather than a TypeError due to same-kind casting.
13+
* ``np.copyto(float32_arr, 1e300, casting="safe")`` will overflow to ``inf``
14+
(float32 cannot hold ``1e300``) rather than raising a TypeError.
15+
16+
Further, only the dtype is used when assigning NumPy scalars (or 0-d arrays),
17+
meaning that the following behaves differently:
18+
19+
* ``np.copyto(float32_arr, np.float64(3.0), casting="safe")`` raises.
20+
* ``np.copyto(int8_arr, np.int64(100), casting="safe")`` raises.
21+
Previously, NumPy checked whether the 100 fits the ``int8_arr``.
22+
23+
This aligns ``copyto``, ``full``, and ``full_like`` with the correct NumPy 2
24+
behavior.

numpy/_core/src/multiarray/abstractdtypes.c

Lines changed: 114 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -378,3 +378,117 @@ NPY_NO_EXPORT PyArray_DTypeMeta PyArray_PyComplexDType = {{{
378378
.dt_slots = &pycomplexdtype_slots,
379379
.scalar_type = NULL, /* set in initialize_and_map_pytypes_to_dtypes */
380380
};
381+
382+
383+
/*
384+
* Additional functions to deal with Python literal int, float, complex
385+
*/
386+
/*
387+
* This function takes an existing array operand and if the new descr does
388+
* not match, replaces it with a new array that has the correct descriptor
389+
* and holds exactly the scalar value.
390+
*/
391+
NPY_NO_EXPORT int
392+
npy_update_operand_for_scalar(
393+
PyArrayObject **operand, PyObject *scalar, PyArray_Descr *descr,
394+
NPY_CASTING casting)
395+
{
396+
if (PyArray_EquivTypes(PyArray_DESCR(*operand), descr)) {
397+
/*
398+
* TODO: This is an unfortunate work-around for legacy type resolvers
399+
* (see `convert_ufunc_arguments` in `ufunc_object.c`), that
400+
* currently forces us to replace the array.
401+
*/
402+
if (!(PyArray_FLAGS(*operand) & NPY_ARRAY_WAS_PYTHON_INT)) {
403+
return 0;
404+
}
405+
}
406+
else if (NPY_UNLIKELY(casting == NPY_EQUIV_CASTING) &&
407+
descr->type_num != NPY_OBJECT) {
408+
/*
409+
* increadibly niche, but users could pass equiv casting and we
410+
* actually need to cast. Let object pass (technically correct) but
411+
* in all other cases, we don't technically consider equivalent.
412+
* NOTE(seberg): I don't think we should be beholden to this logic.
413+
*/
414+
PyErr_Format(PyExc_TypeError,
415+
"cannot cast Python %s to %S under the casting rule 'equiv'",
416+
Py_TYPE(scalar)->tp_name, descr);
417+
return -1;
418+
}
419+
420+
Py_INCREF(descr);
421+
PyArrayObject *new = (PyArrayObject *)PyArray_NewFromDescr(
422+
&PyArray_Type, descr, 0, NULL, NULL, NULL, 0, NULL);
423+
Py_SETREF(*operand, new);
424+
if (*operand == NULL) {
425+
return -1;
426+
}
427+
if (scalar == NULL) {
428+
/* The ufunc.resolve_dtypes paths can go here. Anything should go. */
429+
return 0;
430+
}
431+
return PyArray_SETITEM(new, PyArray_BYTES(*operand), scalar);
432+
}
433+
434+
435+
/*
436+
* When a user passed a Python literal (int, float, complex), special promotion
437+
* rules mean that we don't know the exact descriptor that should be used.
438+
*
439+
* Typically, this just doesn't really matter. Unfortunately, there are two
440+
* exceptions:
441+
* 1. The user might have passed `signature=` which may not be compatible.
442+
* In that case, we cannot really assume "safe" casting.
443+
* 2. It is at least fathomable that a DType doesn't deal with this directly.
444+
* or that using the original int64/object is wrong in the type resolution.
445+
*
446+
* The solution is to assume that we can use the common DType of the signature
447+
* and the Python scalar DType (`in_DT`) as a safe intermediate.
448+
*/
449+
NPY_NO_EXPORT PyArray_Descr *
450+
npy_find_descr_for_scalar(
451+
PyObject *scalar, PyArray_Descr *original_descr,
452+
PyArray_DTypeMeta *in_DT, PyArray_DTypeMeta *op_DT)
453+
{
454+
PyArray_Descr *res;
455+
/* There is a good chance, descriptors already match... */
456+
if (NPY_DTYPE(original_descr) == op_DT) {
457+
Py_INCREF(original_descr);
458+
return original_descr;
459+
}
460+
461+
PyArray_DTypeMeta *common = PyArray_CommonDType(in_DT, op_DT);
462+
if (common == NULL) {
463+
PyErr_Clear();
464+
/* This is fine. We simply assume the original descr is viable. */
465+
Py_INCREF(original_descr);
466+
return original_descr;
467+
}
468+
/* A very likely case is that there is nothing to do: */
469+
if (NPY_DTYPE(original_descr) == common) {
470+
Py_DECREF(common);
471+
Py_INCREF(original_descr);
472+
return original_descr;
473+
}
474+
if (!NPY_DT_is_parametric(common) ||
475+
/* In some paths we only have a scalar type, can't discover */
476+
scalar == NULL ||
477+
/* If the DType doesn't know the scalar type, guess at default. */
478+
!NPY_DT_CALL_is_known_scalar_type(common, Py_TYPE(scalar))) {
479+
if (common->singleton != NULL) {
480+
Py_INCREF(common->singleton);
481+
res = common->singleton;
482+
Py_INCREF(res);
483+
}
484+
else {
485+
res = NPY_DT_CALL_default_descr(common);
486+
}
487+
}
488+
else {
489+
res = NPY_DT_CALL_discover_descr_from_pyobject(common, scalar);
490+
}
491+
492+
Py_DECREF(common);
493+
return res;
494+
}

numpy/_core/src/multiarray/abstractdtypes.h

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
#ifndef NUMPY_CORE_SRC_MULTIARRAY_ABSTRACTDTYPES_H_
22
#define NUMPY_CORE_SRC_MULTIARRAY_ABSTRACTDTYPES_H_
33

4+
#include "numpy/ndarraytypes.h"
45
#include "arrayobject.h"
56
#include "dtypemeta.h"
67

@@ -68,6 +69,19 @@ npy_mark_tmp_array_if_pyscalar(
6869
return 0;
6970
}
7071

72+
73+
NPY_NO_EXPORT int
74+
npy_update_operand_for_scalar(
75+
PyArrayObject **operand, PyObject *scalar, PyArray_Descr *descr,
76+
NPY_CASTING casting);
77+
78+
79+
NPY_NO_EXPORT PyArray_Descr *
80+
npy_find_descr_for_scalar(
81+
PyObject *scalar, PyArray_Descr *original_descr,
82+
PyArray_DTypeMeta *in_DT, PyArray_DTypeMeta *op_DT);
83+
84+
7185
#ifdef __cplusplus
7286
}
7387
#endif

numpy/_core/src/multiarray/array_assign_scalar.c

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -243,8 +243,7 @@ PyArray_AssignRawScalar(PyArrayObject *dst,
243243
}
244244

245245
/* Check the casting rule */
246-
if (!can_cast_scalar_to(src_dtype, src_data,
247-
PyArray_DESCR(dst), casting)) {
246+
if (!PyArray_CanCastTypeTo(src_dtype, PyArray_DESCR(dst), casting)) {
248247
npy_set_invalid_cast_error(
249248
src_dtype, PyArray_DESCR(dst), casting, NPY_TRUE);
250249
return -1;

numpy/_core/src/multiarray/multiarraymodule.c

Lines changed: 48 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1929,29 +1929,63 @@ array_asfortranarray(PyObject *NPY_UNUSED(ignored),
19291929

19301930

19311931
static PyObject *
1932-
array_copyto(PyObject *NPY_UNUSED(ignored), PyObject *args, PyObject *kwds)
1932+
array_copyto(PyObject *NPY_UNUSED(ignored),
1933+
PyObject *const *args, Py_ssize_t len_args, PyObject *kwnames)
19331934
{
1934-
static char *kwlist[] = {"dst", "src", "casting", "where", NULL};
1935-
PyObject *wheremask_in = NULL;
1936-
PyArrayObject *dst = NULL, *src = NULL, *wheremask = NULL;
1935+
PyObject *dst_obj, *src_obj, *wheremask_in = NULL;
1936+
PyArrayObject *src = NULL, *wheremask = NULL;
19371937
NPY_CASTING casting = NPY_SAME_KIND_CASTING;
1938+
NPY_PREPARE_ARGPARSER;
19381939

1939-
if (!PyArg_ParseTupleAndKeywords(args, kwds, "O!O&|O&O:copyto", kwlist,
1940-
&PyArray_Type, &dst,
1941-
&PyArray_Converter, &src,
1942-
&PyArray_CastingConverter, &casting,
1943-
&wheremask_in)) {
1940+
if (npy_parse_arguments("copyto", args, len_args, kwnames,
1941+
"dst", NULL, &dst_obj,
1942+
"src", NULL, &src_obj,
1943+
"|casting", &PyArray_CastingConverter, &casting,
1944+
"|where", NULL, &wheremask_in,
1945+
NULL, NULL, NULL) < 0) {
19441946
goto fail;
19451947
}
19461948

1949+
if (!PyArray_Check(dst_obj)) {
1950+
PyErr_Format(PyExc_TypeError,
1951+
"copyto() argument 1 must be a numpy.ndarray, not %s",
1952+
Py_TYPE(dst_obj)->tp_name);
1953+
}
1954+
PyArrayObject *dst = (PyArrayObject *)dst_obj;
1955+
1956+
src = (PyArrayObject *)PyArray_FromAny(src_obj, NULL, 0, 0, 0, NULL);
1957+
if (src == NULL) {
1958+
goto fail;
1959+
}
1960+
PyArray_DTypeMeta *DType = NPY_DTYPE(PyArray_DESCR(src));
1961+
Py_INCREF(DType);
1962+
if (npy_mark_tmp_array_if_pyscalar(src_obj, src, &DType)) {
1963+
/* The user passed a Python scalar */
1964+
PyArray_Descr *descr = npy_find_descr_for_scalar(
1965+
src_obj, PyArray_DESCR(src), DType,
1966+
NPY_DTYPE(PyArray_DESCR(dst)));
1967+
Py_DECREF(DType);
1968+
if (descr == NULL) {
1969+
goto fail;
1970+
}
1971+
int res = npy_update_operand_for_scalar(&src, src_obj, descr, casting);
1972+
Py_DECREF(descr);
1973+
if (res < 0) {
1974+
goto fail;
1975+
}
1976+
}
1977+
else {
1978+
Py_DECREF(DType);
1979+
}
1980+
19471981
if (wheremask_in != NULL) {
19481982
/* Get the boolean where mask */
1949-
PyArray_Descr *dtype = PyArray_DescrFromType(NPY_BOOL);
1950-
if (dtype == NULL) {
1983+
PyArray_Descr *descr = PyArray_DescrFromType(NPY_BOOL);
1984+
if (descr == NULL) {
19511985
goto fail;
19521986
}
19531987
wheremask = (PyArrayObject *)PyArray_FromAny(wheremask_in,
1954-
dtype, 0, 0, 0, NULL);
1988+
descr, 0, 0, 0, NULL);
19551989
if (wheremask == NULL) {
19561990
goto fail;
19571991
}
@@ -4431,7 +4465,7 @@ static struct PyMethodDef array_module_methods[] = {
44314465
METH_FASTCALL | METH_KEYWORDS, NULL},
44324466
{"copyto",
44334467
(PyCFunction)array_copyto,
4434-
METH_VARARGS|METH_KEYWORDS, NULL},
4468+
METH_FASTCALL | METH_KEYWORDS, NULL},
44354469
{"nested_iters",
44364470
(PyCFunction)NpyIter_NestedIters,
44374471
METH_VARARGS|METH_KEYWORDS, NULL},
@@ -5129,7 +5163,7 @@ PyMODINIT_FUNC PyInit__multiarray_umath(void) {
51295163

51305164
// initialize static reference to a zero-like array
51315165
npy_static_pydata.zero_pyint_like_arr = PyArray_ZEROS(
5132-
0, NULL, NPY_LONG, NPY_FALSE);
5166+
0, NULL, NPY_DEFAULT_INT, NPY_FALSE);
51335167
if (npy_static_pydata.zero_pyint_like_arr == NULL) {
51345168
goto err;
51355169
}

0 commit comments

Comments
 (0)