Skip to content

Commit 8c9cf79

Browse files
authored
Merge pull request numpy#20088 from WarrenWeckesser/fix-gh-200077
BUG: core: result_type(0, np.timedelta64(4)) would seg. fault.
2 parents c3c59da + a37f6d2 commit 8c9cf79

File tree

2 files changed

+49
-20
lines changed

2 files changed

+49
-20
lines changed

numpy/core/src/multiarray/convert_datatype.c

Lines changed: 41 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1548,6 +1548,40 @@ should_use_min_scalar(npy_intp narrs, PyArrayObject **arr,
15481548
}
15491549

15501550

1551+
/*
1552+
* Utility function used only in PyArray_ResultType for value-based logic.
1553+
* See that function for the meaning and contents of the parameters.
1554+
*/
1555+
static PyArray_Descr *
1556+
get_descr_from_cast_or_value(
1557+
npy_intp i,
1558+
PyArrayObject *arrs[],
1559+
npy_intp ndtypes,
1560+
PyArray_Descr *descriptor,
1561+
PyArray_DTypeMeta *common_dtype)
1562+
{
1563+
PyArray_Descr *curr;
1564+
if (NPY_LIKELY(i < ndtypes ||
1565+
!(PyArray_FLAGS(arrs[i-ndtypes]) & _NPY_ARRAY_WAS_PYSCALAR))) {
1566+
curr = PyArray_CastDescrToDType(descriptor, common_dtype);
1567+
}
1568+
else {
1569+
/*
1570+
* Unlike `PyArray_CastToDTypeAndPromoteDescriptors`, deal with
1571+
* plain Python values "graciously". This recovers the original
1572+
* value the long route, but it should almost never happen...
1573+
*/
1574+
PyObject *tmp = PyArray_GETITEM(arrs[i-ndtypes],
1575+
PyArray_BYTES(arrs[i-ndtypes]));
1576+
if (tmp == NULL) {
1577+
return NULL;
1578+
}
1579+
curr = NPY_DT_CALL_discover_descr_from_pyobject(common_dtype, tmp);
1580+
Py_DECREF(tmp);
1581+
}
1582+
return curr;
1583+
}
1584+
15511585
/*NUMPY_API
15521586
*
15531587
* Produces the result type of a bunch of inputs, using the same rules
@@ -1684,28 +1718,15 @@ PyArray_ResultType(
16841718
result = NPY_DT_CALL_default_descr(common_dtype);
16851719
}
16861720
else {
1687-
result = PyArray_CastDescrToDType(all_descriptors[0], common_dtype);
1721+
result = get_descr_from_cast_or_value(
1722+
0, arrs, ndtypes, all_descriptors[0], common_dtype);
1723+
if (result == NULL) {
1724+
goto error;
1725+
}
16881726

16891727
for (npy_intp i = 1; i < ndtypes+narrs; i++) {
1690-
PyArray_Descr *curr;
1691-
if (NPY_LIKELY(i < ndtypes ||
1692-
!(PyArray_FLAGS(arrs[i-ndtypes]) & _NPY_ARRAY_WAS_PYSCALAR))) {
1693-
curr = PyArray_CastDescrToDType(all_descriptors[i], common_dtype);
1694-
}
1695-
else {
1696-
/*
1697-
* Unlike `PyArray_CastToDTypeAndPromoteDescriptors` deal with
1698-
* plain Python values "graciously". This recovers the original
1699-
* value the long route, but it should almost never happen...
1700-
*/
1701-
PyObject *tmp = PyArray_GETITEM(
1702-
arrs[i-ndtypes], PyArray_BYTES(arrs[i-ndtypes]));
1703-
if (tmp == NULL) {
1704-
goto error;
1705-
}
1706-
curr = NPY_DT_CALL_discover_descr_from_pyobject(common_dtype, tmp);
1707-
Py_DECREF(tmp);
1708-
}
1728+
PyArray_Descr *curr = get_descr_from_cast_or_value(
1729+
i, arrs, ndtypes, all_descriptors[i], common_dtype);
17091730
if (curr == NULL) {
17101731
goto error;
17111732
}

numpy/core/tests/test_dtype.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1590,6 +1590,14 @@ def test_subscript_scalar(self) -> None:
15901590
assert np.dtype[Any]
15911591

15921592

1593+
def test_result_type_integers_and_unitless_timedelta64():
1594+
# Regression test for gh-20077. The following call of `result_type`
1595+
# would cause a seg. fault.
1596+
td = np.timedelta64(4)
1597+
result = np.result_type(0, td)
1598+
assert_dtype_equal(result, td.dtype)
1599+
1600+
15931601
@pytest.mark.skipif(sys.version_info >= (3, 9), reason="Requires python 3.8")
15941602
def test_class_getitem_38() -> None:
15951603
match = "Type subscription requires python >= 3.9"

0 commit comments

Comments
 (0)