Skip to content

Commit 0dd77b4

Browse files
committed
BUG: np.take handle 64-bit indices on 32-bit platforms
See numpy#25607
1 parent 76807f0 commit 0dd77b4

File tree

2 files changed

+48
-5
lines changed

2 files changed

+48
-5
lines changed

numpy/_core/src/multiarray/item_selection.c

Lines changed: 16 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -231,21 +231,32 @@ PyArray_TakeFrom(PyArrayObject *self0, PyObject *indices0, int axis,
231231
PyArrayObject *out, NPY_CLIPMODE clipmode)
232232
{
233233
PyArray_Descr *dtype;
234-
PyArrayObject *obj = NULL, *self, *indices;
234+
PyArrayObject *obj = NULL, *self, *indices, *tmp;
235235
npy_intp nd, i, n, m, max_item, chunk, itemsize, nelem;
236236
npy_intp shape[NPY_MAXDIMS];
237237

238238
npy_bool needs_refcounting;
239239

240-
indices = NULL;
240+
indices = tmp = NULL;
241241
self = (PyArrayObject *)PyArray_CheckAxis(self0, &axis,
242242
NPY_ARRAY_CARRAY_RO);
243243
if (self == NULL) {
244244
return NULL;
245245
}
246-
indices = (PyArrayObject *)PyArray_ContiguousFromAny(indices0,
247-
NPY_INTP,
248-
0, 0);
246+
tmp = (PyArrayObject *)PyArray_ContiguousFromAny(indices0,
247+
NPY_INT64,
248+
0, 0);
249+
if (tmp == NULL) {
250+
goto fail;
251+
}
252+
253+
if (NPY_SIZEOF_INTP != 8) {
254+
indices = (PyArrayObject *)PyArray_Cast(tmp, NPY_INTP);
255+
Py_DECREF(tmp);
256+
} else {
257+
indices = tmp;
258+
}
259+
249260
if (indices == NULL) {
250261
goto fail;
251262
}

numpy/_core/tests/test_numeric.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -324,6 +324,38 @@ def test_take(self):
324324
out = np.take(a, indices)
325325
assert_equal(out, tgt)
326326

327+
# take 32 64
328+
x32 = np.array([1, 2, 3, 4, 5], dtype=np.int32)
329+
x64 = np.array([0, 2, 2, 3], dtype=np.int64)
330+
tgt = np.array([1, 3, 3, 4], dtype=np.int32)
331+
out = np.take(x32, x64)
332+
assert_equal(out, tgt)
333+
assert_equal(out.dtype, tgt.dtype)
334+
335+
# take 64 32
336+
x64 = np.array([1, 2, 3, 4, 5], dtype=np.int64)
337+
x32 = np.array([0, 2, 2, 3], dtype=np.int32)
338+
tgt = np.array([1, 3, 3, 4], dtype=np.int64)
339+
out = np.take(x64, x32)
340+
assert_equal(out, tgt)
341+
assert_equal(out.dtype, tgt.dtype)
342+
343+
# take 32 32
344+
x32_0 = np.array([1, 2, 3, 4, 5], dtype=np.int32)
345+
x32_1 = np.array([0, 2, 2, 3], dtype=np.int32)
346+
tgt = np.array([1, 3, 3, 4], dtype=np.int32)
347+
out = np.take(x32_0, x32_1)
348+
assert_equal(out, tgt)
349+
assert_equal(out.dtype, tgt.dtype)
350+
351+
# take 64 64
352+
x64_0 = np.array([1, 2, 3, 4, 5], dtype=np.int64)
353+
x64_1 = np.array([0, 2, 2, 3], dtype=np.int64)
354+
tgt = np.array([1, 3, 3, 4], dtype=np.int64)
355+
out = np.take(x64_0, x64_1)
356+
assert_equal(out, tgt)
357+
assert_equal(out.dtype, tgt.dtype)
358+
327359
def test_trace(self):
328360
c = [[1, 2], [3, 4], [5, 6]]
329361
assert_equal(np.trace(c), 5)

0 commit comments

Comments
 (0)