Skip to content

Commit 8b7fb0e

Browse files
committed
BUG: Fix user dtype can-cast with python scalar during promotion
The can-cast code for "Python scalars" was old and did not correctly take into account possible user-dtypes with respect to NEP 50 weak promotion. To do this, we already had the necessary helper functions that go via promotion (although it took me some brooding to remember ;)). So the fix is rather simple. Actually adding CI/test for the fix is unfortunately hard as it requires such a user DType.
1 parent b9ba526 commit 8b7fb0e

File tree

1 file changed

+16
-5
lines changed

1 file changed

+16
-5
lines changed

numpy/_core/src/multiarray/convert_datatype.c

Lines changed: 16 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -714,18 +714,29 @@ can_cast_pyscalar_scalar_to(
714714
}
715715

716716
/*
717-
* For all other cases we use the default dtype.
717+
* For all other cases we need to make a bit of a dance to find the cast
718+
* safety. We do so by finding the descriptor for the "scalar" (without
719+
* a value; for parametric user dtypes a value may be needed eventually).
718720
*/
719-
PyArray_Descr *from;
721+
PyArray_DTypeMeta *from_DType;
722+
PyArray_Descr *default_dtype;
720723
if (flags & NPY_ARRAY_WAS_PYTHON_INT) {
721-
from = PyArray_DescrFromType(NPY_LONG);
724+
default_dtype = PyArray_DescrNewFromType(NPY_INTP);
725+
from_DType = &PyArray_PyLongDType;
722726
}
723727
else if (flags & NPY_ARRAY_WAS_PYTHON_FLOAT) {
724-
from = PyArray_DescrFromType(NPY_DOUBLE);
728+
default_dtype = PyArray_DescrNewFromType(NPY_FLOAT64);
729+
from_DType = &PyArray_PyFloatDType;
725730
}
726731
else {
727-
from = PyArray_DescrFromType(NPY_CDOUBLE);
732+
default_dtype = PyArray_DescrNewFromType(NPY_COMPLEX128);
733+
from_DType = &PyArray_PyComplexDType;
728734
}
735+
736+
PyArray_Descr *from = npy_find_descr_for_scalar(
737+
NULL, default_dtype, from_DType, NPY_DTYPE(to));
738+
Py_DECREF(default_dtype);
739+
729740
int res = PyArray_CanCastTypeTo(from, to, casting);
730741
Py_DECREF(from);
731742
return res;

0 commit comments

Comments
 (0)