Skip to content

Commit 0cea652

Browse files
authored
Merge pull request numpy#14933 from seberg/cleanup-converttocommontype
API: Use `ResultType` in `PyArray_ConvertToCommonType`
2 parents 01289c2 + c5da77d commit 0cea652

File tree

4 files changed

+59
-85
lines changed

4 files changed

+59
-85
lines changed
Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
Scalar promotion in ``PyArray_ConvertToCommonType``
2+
---------------------------------------------------
3+
4+
The promotion of mixed scalars and arrays in ``PyArray_ConvertToCommonType``
5+
has been changed to adhere to those used by ``np.result_type``.
6+
This means that input such as ``(1000, np.array([1], dtype=np.uint8)))``
7+
will now return ``uint16`` dtypes. In most cases the behaviour is unchanged.
8+
Note that the use of this C-API function is generally discouarged.
9+
This also fixes ``np.choose`` to behave the same way as the rest of NumPy
10+
in this respect.

doc/source/reference/c-api/array.rst

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1255,14 +1255,18 @@ Converting data types
12551255
12561256
Convert a sequence of Python objects contained in *op* to an array
12571257
of ndarrays each having the same data type. The type is selected
1258-
based on the typenumber (larger type number is chosen over a
1259-
smaller one) ignoring objects that are only scalars. The length of
1260-
the sequence is returned in *n*, and an *n* -length array of
1261-
:c:type:`PyArrayObject` pointers is the return value (or ``NULL`` if an
1262-
error occurs). The returned array must be freed by the caller of
1263-
this routine (using :c:func:`PyDataMem_FREE` ) and all the array objects
1264-
in it ``DECREF`` 'd or a memory-leak will occur. The example
1265-
template-code below shows a typically usage:
1258+
in the same way as `PyArray_ResultType`. The length of the sequence is
1259+
returned in *n*, and an *n* -length array of :c:type:`PyArrayObject`
1260+
pointers is the return value (or ``NULL`` if an error occurs).
1261+
The returned array must be freed by the caller of this routine
1262+
(using :c:func:`PyDataMem_FREE` ) and all the array objects in it
1263+
``DECREF`` 'd or a memory-leak will occur. The example template-code
1264+
below shows a typically usage:
1265+
1266+
.. versionchanged:: 1.18.0
1267+
A mix of scalars and zero-dimensional arrays now produces a type
1268+
capable of holding the scalar value.
1269+
Previously priority was given to the dtype of the arrays.
12661270
12671271
.. code-block:: c
12681272

numpy/core/src/multiarray/convert_datatype.c

Lines changed: 28 additions & 77 deletions
Original file line numberDiff line numberDiff line change
@@ -2121,15 +2121,19 @@ PyArray_ObjectType(PyObject *op, int minimum_type)
21212121

21222122
/* Raises error when len(op) == 0 */
21232123

2124-
/*NUMPY_API*/
2124+
/*NUMPY_API
2125+
*
2126+
* This function is only used in one place within NumPy and should
2127+
* generally be avoided. It is provided mainly for backward compatibility.
2128+
*
2129+
* The user of the function has to free the returned array.
2130+
*/
21252131
NPY_NO_EXPORT PyArrayObject **
21262132
PyArray_ConvertToCommonType(PyObject *op, int *retn)
21272133
{
2128-
int i, n, allscalars = 0;
2134+
int i, n;
2135+
PyArray_Descr *common_descr = NULL;
21292136
PyArrayObject **mps = NULL;
2130-
PyArray_Descr *intype = NULL, *stype = NULL;
2131-
PyArray_Descr *newtype = NULL;
2132-
NPY_SCALARKIND scalarkind = NPY_NOSCALAR, intypekind = NPY_NOSCALAR;
21332137

21342138
*retn = n = PySequence_Length(op);
21352139
if (n == 0) {
@@ -2165,94 +2169,41 @@ PyArray_ConvertToCommonType(PyObject *op, int *retn)
21652169
}
21662170

21672171
for (i = 0; i < n; i++) {
2168-
PyObject *otmp = PySequence_GetItem(op, i);
2169-
if (otmp == NULL) {
2172+
/* Convert everything to an array, this could be optimized away */
2173+
PyObject *tmp = PySequence_GetItem(op, i);
2174+
if (tmp == NULL) {
21702175
goto fail;
21712176
}
2172-
if (!PyArray_CheckAnyScalar(otmp)) {
2173-
newtype = PyArray_DescrFromObject(otmp, intype);
2174-
Py_DECREF(otmp);
2175-
Py_XDECREF(intype);
2176-
if (newtype == NULL) {
2177-
goto fail;
2178-
}
2179-
intype = newtype;
2180-
intypekind = PyArray_ScalarKind(intype->type_num, NULL);
2181-
}
2182-
else {
2183-
newtype = PyArray_DescrFromObject(otmp, stype);
2184-
Py_DECREF(otmp);
2185-
Py_XDECREF(stype);
2186-
if (newtype == NULL) {
2187-
goto fail;
2188-
}
2189-
stype = newtype;
2190-
scalarkind = PyArray_ScalarKind(newtype->type_num, NULL);
2191-
mps[i] = (PyArrayObject *)Py_None;
2192-
Py_INCREF(Py_None);
2193-
}
2194-
}
2195-
if (intype == NULL) {
2196-
/* all scalars */
2197-
allscalars = 1;
2198-
intype = stype;
2199-
Py_INCREF(intype);
2200-
for (i = 0; i < n; i++) {
2201-
Py_XDECREF(mps[i]);
2202-
mps[i] = NULL;
2203-
}
2204-
}
2205-
else if ((stype != NULL) && (intypekind != scalarkind)) {
2206-
/*
2207-
* we need to upconvert to type that
2208-
* handles both intype and stype
2209-
* also don't forcecast the scalars.
2210-
*/
2211-
if (!PyArray_CanCoerceScalar(stype->type_num,
2212-
intype->type_num,
2213-
scalarkind)) {
2214-
newtype = PyArray_PromoteTypes(intype, stype);
2215-
Py_XDECREF(intype);
2216-
intype = newtype;
2217-
if (newtype == NULL) {
2218-
goto fail;
2219-
}
2220-
}
2221-
for (i = 0; i < n; i++) {
2222-
Py_XDECREF(mps[i]);
2223-
mps[i] = NULL;
2177+
2178+
mps[i] = (PyArrayObject *)PyArray_FROM_O(tmp);
2179+
Py_DECREF(tmp);
2180+
if (mps[i] == NULL) {
2181+
goto fail;
22242182
}
22252183
}
22262184

2185+
common_descr = PyArray_ResultType(n, mps, 0, NULL);
2186+
if (common_descr == NULL) {
2187+
goto fail;
2188+
}
22272189

2228-
/* Make sure all arrays are actual array objects. */
2190+
/* Make sure all arrays are contiguous and have the correct dtype. */
22292191
for (i = 0; i < n; i++) {
22302192
int flags = NPY_ARRAY_CARRAY;
2231-
PyObject *otmp = PySequence_GetItem(op, i);
2193+
PyArrayObject *tmp = mps[i];
22322194

2233-
if (otmp == NULL) {
2234-
goto fail;
2235-
}
2236-
if (!allscalars && ((PyObject *)(mps[i]) == Py_None)) {
2237-
/* forcecast scalars */
2238-
flags |= NPY_ARRAY_FORCECAST;
2239-
Py_DECREF(Py_None);
2240-
}
2241-
Py_INCREF(intype);
2242-
mps[i] = (PyArrayObject*)PyArray_FromAny(otmp, intype, 0, 0,
2243-
flags, NULL);
2244-
Py_DECREF(otmp);
2195+
Py_INCREF(common_descr);
2196+
mps[i] = (PyArrayObject *)PyArray_FromArray(tmp, common_descr, flags);
2197+
Py_DECREF(tmp);
22452198
if (mps[i] == NULL) {
22462199
goto fail;
22472200
}
22482201
}
2249-
Py_DECREF(intype);
2250-
Py_XDECREF(stype);
2202+
Py_DECREF(common_descr);
22512203
return mps;
22522204

22532205
fail:
2254-
Py_XDECREF(intype);
2255-
Py_XDECREF(stype);
2206+
Py_XDECREF(common_descr);
22562207
*retn = 0;
22572208
for (i = 0; i < n; i++) {
22582209
Py_XDECREF(mps[i]);

numpy/core/tests/test_multiarray.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6465,6 +6465,15 @@ def test_broadcast2(self):
64656465
A = np.choose(self.ind, (self.x, self.y2))
64666466
assert_equal(A, [[2, 2, 3], [2, 2, 3]])
64676467

6468+
@pytest.mark.parametrize("ops",
6469+
[(1000, np.array([1], dtype=np.uint8)),
6470+
(-1, np.array([1], dtype=np.uint8)),
6471+
(1., np.float32(3)),
6472+
(1., np.array([3], dtype=np.float32))],)
6473+
def test_output_dtype(self, ops):
6474+
expected_dt = np.result_type(*ops)
6475+
assert(np.choose([0], ops).dtype == expected_dt)
6476+
64686477

64696478
class TestRepeat(object):
64706479
def setup(self):

0 commit comments

Comments
 (0)