Skip to content

Commit 302a297

Browse files
committed
implement NPY_ARRAY_SAME_KIND_CASTING and use in np.take
1 parent eacef41 commit 302a297

File tree

3 files changed

+18
-17
lines changed

3 files changed

+18
-17
lines changed

numpy/_core/src/multiarray/arrayobject.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,14 @@ static const int NPY_ARRAY_WAS_PYTHON_COMPLEX = (1 << 28);
5555
static const int NPY_ARRAY_WAS_INT_AND_REPLACED = (1 << 27);
5656
static const int NPY_ARRAY_WAS_PYTHON_LITERAL = (1 << 30 | 1 << 29 | 1 << 28);
5757

58+
/*
59+
* This flag allows same kind casting, similar to NPY_ARRAY_FORCECAST.
60+
*
61+
* An array never has this flag set; they're only used as parameter
62+
* flags to the various FromAny functions.
63+
*/
64+
static const int NPY_ARRAY_SAME_KIND_CASTING = (1 << 26);
65+
5866
#ifdef __cplusplus
5967
}
6068
#endif

numpy/_core/src/multiarray/ctors.c

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
#include <structmember.h>
88

99
#include "numpy/arrayobject.h"
10+
#include "arrayobject.h"
1011
#include "numpy/arrayscalars.h"
1112

1213
#include "numpy/npy_math.h"
@@ -1908,6 +1909,10 @@ PyArray_FromArray(PyArrayObject *arr, PyArray_Descr *newtype, int flags)
19081909
newtype->elsize = oldtype->elsize;
19091910
}
19101911

1912+
if (flags & NPY_ARRAY_SAME_KIND_CASTING) {
1913+
casting = NPY_SAME_KIND_CASTING;
1914+
}
1915+
19111916
/* If the casting if forced, use the 'unsafe' casting rule */
19121917
if (flags & NPY_ARRAY_FORCECAST) {
19131918
casting = NPY_UNSAFE_CASTING;

numpy/_core/src/multiarray/item_selection.c

Lines changed: 5 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -244,23 +244,11 @@ PyArray_TakeFrom(PyArrayObject *self0, PyObject *indices0, int axis,
244244
return NULL;
245245
}
246246

247-
#if NPY_SIZEOF_INTP == NPY_SIZEOF_INT
248-
PyArrayObject *tmp;
249-
tmp = (PyArrayObject *)PyArray_ContiguousFromAny(indices0,
250-
NPY_INT64,
251-
0, 0);
252-
if (tmp == NULL) {
253-
goto fail;
254-
}
255-
256-
indices = (PyArrayObject *)PyArray_Cast(tmp, NPY_INTP);
257-
Py_DECREF(tmp);
258-
#else
259-
indices = (PyArrayObject *)PyArray_ContiguousFromAny(indices0,
260-
NPY_INT64,
261-
0, 0);
262-
#endif
263-
247+
indices = (PyArrayObject *)PyArray_FromAny(indices0,
248+
PyArray_DescrFromType(NPY_INTP),
249+
0, 0,
250+
NPY_ARRAY_SAME_KIND_CASTING,
251+
NULL);
264252
if (indices == NULL) {
265253
goto fail;
266254
}

0 commit comments

Comments
 (0)