Skip to content

Commit c5da77d

Browse files
committed
API: Use ResultType in PyArray_ConvertToCommonType
This slightly modifies the behaviour of `np.choose` (practically a bug fix) and the public function itself. The function is not used within e.g. SciPy, so the small performance hit of this implementation is probably insignificant. The change should help clean up dtypes a bit, since the whole "scalar cast" logic is brittle and should be deprecated/removed, and this is probably one of the few places actually using it. The choose change means that: ``` np.choose([0], (1000, np.array([1], dtype=np.uint8))) ``` will actually return a value of 1000 (the dtype not being uint8).
1 parent 9aeb751 commit c5da77d

File tree

4 files changed

+59
-85
lines changed

4 files changed

+59
-85
lines changed
Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
Scalar promotion in ``PyArray_ConvertToCommonType``
2+
---------------------------------------------------
3+
4+
The promotion of mixed scalars and arrays in ``PyArray_ConvertToCommonType``
5+
has been changed to adhere to those used by ``np.result_type``.
6+
This means that input such as ``(1000, np.array([1], dtype=np.uint8)))``
7+
will now return ``uint16`` dtypes. In most cases the behaviour is unchanged.
8+
Note that the use of this C-API function is generally discouarged.
9+
This also fixes ``np.choose`` to behave the same way as the rest of NumPy
10+
in this respect.

doc/source/reference/c-api/array.rst

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1255,14 +1255,18 @@ Converting data types
12551255
12561256
Convert a sequence of Python objects contained in *op* to an array
12571257
of ndarrays each having the same data type. The type is selected
1258-
based on the typenumber (larger type number is chosen over a
1259-
smaller one) ignoring objects that are only scalars. The length of
1260-
the sequence is returned in *n*, and an *n* -length array of
1261-
:c:type:`PyArrayObject` pointers is the return value (or ``NULL`` if an
1262-
error occurs). The returned array must be freed by the caller of
1263-
this routine (using :c:func:`PyDataMem_FREE` ) and all the array objects
1264-
in it ``DECREF`` 'd or a memory-leak will occur. The example
1265-
template-code below shows a typically usage:
1258+
in the same way as `PyArray_ResultType`. The length of the sequence is
1259+
returned in *n*, and an *n* -length array of :c:type:`PyArrayObject`
1260+
pointers is the return value (or ``NULL`` if an error occurs).
1261+
The returned array must be freed by the caller of this routine
1262+
(using :c:func:`PyDataMem_FREE` ) and all the array objects in it
1263+
``DECREF`` 'd or a memory-leak will occur. The example template-code
1264+
below shows a typically usage:
1265+
1266+
.. versionchanged:: 1.18.0
1267+
A mix of scalars and zero-dimensional arrays now produces a type
1268+
capable of holding the scalar value.
1269+
Previously priority was given to the dtype of the arrays.
12661270
12671271
.. code-block:: c
12681272

numpy/core/src/multiarray/convert_datatype.c

Lines changed: 28 additions & 77 deletions
Original file line numberDiff line numberDiff line change
@@ -2115,15 +2115,19 @@ PyArray_ObjectType(PyObject *op, int minimum_type)
21152115

21162116
/* Raises error when len(op) == 0 */
21172117

2118-
/*NUMPY_API*/
2118+
/*NUMPY_API
2119+
*
2120+
* This function is only used in one place within NumPy and should
2121+
* generally be avoided. It is provided mainly for backward compatibility.
2122+
*
2123+
* The user of the function has to free the returned array.
2124+
*/
21192125
NPY_NO_EXPORT PyArrayObject **
21202126
PyArray_ConvertToCommonType(PyObject *op, int *retn)
21212127
{
2122-
int i, n, allscalars = 0;
2128+
int i, n;
2129+
PyArray_Descr *common_descr = NULL;
21232130
PyArrayObject **mps = NULL;
2124-
PyArray_Descr *intype = NULL, *stype = NULL;
2125-
PyArray_Descr *newtype = NULL;
2126-
NPY_SCALARKIND scalarkind = NPY_NOSCALAR, intypekind = NPY_NOSCALAR;
21272131

21282132
*retn = n = PySequence_Length(op);
21292133
if (n == 0) {
@@ -2159,94 +2163,41 @@ PyArray_ConvertToCommonType(PyObject *op, int *retn)
21592163
}
21602164

21612165
for (i = 0; i < n; i++) {
2162-
PyObject *otmp = PySequence_GetItem(op, i);
2163-
if (otmp == NULL) {
2166+
/* Convert everything to an array, this could be optimized away */
2167+
PyObject *tmp = PySequence_GetItem(op, i);
2168+
if (tmp == NULL) {
21642169
goto fail;
21652170
}
2166-
if (!PyArray_CheckAnyScalar(otmp)) {
2167-
newtype = PyArray_DescrFromObject(otmp, intype);
2168-
Py_DECREF(otmp);
2169-
Py_XDECREF(intype);
2170-
if (newtype == NULL) {
2171-
goto fail;
2172-
}
2173-
intype = newtype;
2174-
intypekind = PyArray_ScalarKind(intype->type_num, NULL);
2175-
}
2176-
else {
2177-
newtype = PyArray_DescrFromObject(otmp, stype);
2178-
Py_DECREF(otmp);
2179-
Py_XDECREF(stype);
2180-
if (newtype == NULL) {
2181-
goto fail;
2182-
}
2183-
stype = newtype;
2184-
scalarkind = PyArray_ScalarKind(newtype->type_num, NULL);
2185-
mps[i] = (PyArrayObject *)Py_None;
2186-
Py_INCREF(Py_None);
2187-
}
2188-
}
2189-
if (intype == NULL) {
2190-
/* all scalars */
2191-
allscalars = 1;
2192-
intype = stype;
2193-
Py_INCREF(intype);
2194-
for (i = 0; i < n; i++) {
2195-
Py_XDECREF(mps[i]);
2196-
mps[i] = NULL;
2197-
}
2198-
}
2199-
else if ((stype != NULL) && (intypekind != scalarkind)) {
2200-
/*
2201-
* we need to upconvert to type that
2202-
* handles both intype and stype
2203-
* also don't forcecast the scalars.
2204-
*/
2205-
if (!PyArray_CanCoerceScalar(stype->type_num,
2206-
intype->type_num,
2207-
scalarkind)) {
2208-
newtype = PyArray_PromoteTypes(intype, stype);
2209-
Py_XDECREF(intype);
2210-
intype = newtype;
2211-
if (newtype == NULL) {
2212-
goto fail;
2213-
}
2214-
}
2215-
for (i = 0; i < n; i++) {
2216-
Py_XDECREF(mps[i]);
2217-
mps[i] = NULL;
2171+
2172+
mps[i] = (PyArrayObject *)PyArray_FROM_O(tmp);
2173+
Py_DECREF(tmp);
2174+
if (mps[i] == NULL) {
2175+
goto fail;
22182176
}
22192177
}
22202178

2179+
common_descr = PyArray_ResultType(n, mps, 0, NULL);
2180+
if (common_descr == NULL) {
2181+
goto fail;
2182+
}
22212183

2222-
/* Make sure all arrays are actual array objects. */
2184+
/* Make sure all arrays are contiguous and have the correct dtype. */
22232185
for (i = 0; i < n; i++) {
22242186
int flags = NPY_ARRAY_CARRAY;
2225-
PyObject *otmp = PySequence_GetItem(op, i);
2187+
PyArrayObject *tmp = mps[i];
22262188

2227-
if (otmp == NULL) {
2228-
goto fail;
2229-
}
2230-
if (!allscalars && ((PyObject *)(mps[i]) == Py_None)) {
2231-
/* forcecast scalars */
2232-
flags |= NPY_ARRAY_FORCECAST;
2233-
Py_DECREF(Py_None);
2234-
}
2235-
Py_INCREF(intype);
2236-
mps[i] = (PyArrayObject*)PyArray_FromAny(otmp, intype, 0, 0,
2237-
flags, NULL);
2238-
Py_DECREF(otmp);
2189+
Py_INCREF(common_descr);
2190+
mps[i] = (PyArrayObject *)PyArray_FromArray(tmp, common_descr, flags);
2191+
Py_DECREF(tmp);
22392192
if (mps[i] == NULL) {
22402193
goto fail;
22412194
}
22422195
}
2243-
Py_DECREF(intype);
2244-
Py_XDECREF(stype);
2196+
Py_DECREF(common_descr);
22452197
return mps;
22462198

22472199
fail:
2248-
Py_XDECREF(intype);
2249-
Py_XDECREF(stype);
2200+
Py_XDECREF(common_descr);
22502201
*retn = 0;
22512202
for (i = 0; i < n; i++) {
22522203
Py_XDECREF(mps[i]);

numpy/core/tests/test_multiarray.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6488,6 +6488,15 @@ def test_broadcast2(self):
64886488
A = np.choose(self.ind, (self.x, self.y2))
64896489
assert_equal(A, [[2, 2, 3], [2, 2, 3]])
64906490

6491+
@pytest.mark.parametrize("ops",
6492+
[(1000, np.array([1], dtype=np.uint8)),
6493+
(-1, np.array([1], dtype=np.uint8)),
6494+
(1., np.float32(3)),
6495+
(1., np.array([3], dtype=np.float32))],)
6496+
def test_output_dtype(self, ops):
6497+
expected_dt = np.result_type(*ops)
6498+
assert(np.choose([0], ops).dtype == expected_dt)
6499+
64916500

64926501
class TestRepeat(object):
64936502
def setup(self):

0 commit comments

Comments
 (0)