Skip to content

Commit 0516b05

Browse files
authored
Merge pull request numpy#26610 from JuliaPoo/issue-25607-take-32bit-bug
BUG: np.take handle 64-bit indices on 32-bit platforms
2 parents 9ae367a + 62b628c commit 0516b05

File tree

4 files changed

+31
-3
lines changed

4 files changed

+31
-3
lines changed

numpy/_core/src/multiarray/arrayobject.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,14 @@ static const int NPY_ARRAY_WAS_PYTHON_COMPLEX = (1 << 28);
5555
static const int NPY_ARRAY_WAS_INT_AND_REPLACED = (1 << 27);
5656
static const int NPY_ARRAY_WAS_PYTHON_LITERAL = (1 << 30 | 1 << 29 | 1 << 28);
5757

58+
/*
59+
* This flag allows same kind casting, similar to NPY_ARRAY_FORCECAST.
60+
*
61+
* An array never has this flag set; they're only used as parameter
62+
* flags to the various FromAny functions.
63+
*/
64+
static const int NPY_ARRAY_SAME_KIND_CASTING = (1 << 26);
65+
5866
#ifdef __cplusplus
5967
}
6068
#endif

numpy/_core/src/multiarray/ctors.c

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
#include <structmember.h>
88

99
#include "numpy/arrayobject.h"
10+
#include "arrayobject.h"
1011
#include "numpy/arrayscalars.h"
1112

1213
#include "numpy/npy_math.h"
@@ -1908,6 +1909,10 @@ PyArray_FromArray(PyArrayObject *arr, PyArray_Descr *newtype, int flags)
19081909
newtype->elsize = oldtype->elsize;
19091910
}
19101911

1912+
if (flags & NPY_ARRAY_SAME_KIND_CASTING) {
1913+
casting = NPY_SAME_KIND_CASTING;
1914+
}
1915+
19111916
/* If the casting if forced, use the 'unsafe' casting rule */
19121917
if (flags & NPY_ARRAY_FORCECAST) {
19131918
casting = NPY_UNSAFE_CASTING;

numpy/_core/src/multiarray/item_selection.c

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -243,9 +243,12 @@ PyArray_TakeFrom(PyArrayObject *self0, PyObject *indices0, int axis,
243243
if (self == NULL) {
244244
return NULL;
245245
}
246-
indices = (PyArrayObject *)PyArray_ContiguousFromAny(indices0,
247-
NPY_INTP,
248-
0, 0);
246+
247+
indices = (PyArrayObject *)PyArray_FromAny(indices0,
248+
PyArray_DescrFromType(NPY_INTP),
249+
0, 0,
250+
NPY_ARRAY_SAME_KIND_CASTING | NPY_ARRAY_DEFAULT,
251+
NULL);
249252
if (indices == NULL) {
250253
goto fail;
251254
}

numpy/_core/tests/test_numeric.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -324,6 +324,18 @@ def test_take(self):
324324
out = np.take(a, indices)
325325
assert_equal(out, tgt)
326326

327+
pairs = [
328+
(np.int32, np.int32), (np.int32, np.int64),
329+
(np.int64, np.int32), (np.int64, np.int64)
330+
]
331+
for array_type, indices_type in pairs:
332+
x = np.array([1, 2, 3, 4, 5], dtype=array_type)
333+
ind = np.array([0, 2, 2, 3], dtype=indices_type)
334+
tgt = np.array([1, 3, 3, 4], dtype=array_type)
335+
out = np.take(x, ind)
336+
assert_equal(out, tgt)
337+
assert_equal(out.dtype, tgt.dtype)
338+
327339
def test_trace(self):
328340
c = [[1, 2], [3, 4], [5, 6]]
329341
assert_equal(np.trace(c), 5)

0 commit comments

Comments
 (0)