Skip to content

Commit 7a46b32

Browse files
committed
Merge branch 'issue-25607-take-32bit-bug' of https://github.com/JuliaPoo/numpy into issue-25607-take-32bit-bug
2 parents 540cdb3 + ba241ef commit 7a46b32

File tree

1 file changed

+8
-30
lines changed

1 file changed

+8
-30
lines changed

numpy/_core/tests/test_numeric.py

Lines changed: 8 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -324,37 +324,15 @@ def test_take(self):
324324
out = np.take(a, indices)
325325
assert_equal(out, tgt)
326326

327-
# take 32 64
328-
x32 = np.array([1, 2, 3, 4, 5], dtype=np.int32)
329-
x64 = np.array([0, 2, 2, 3], dtype=np.int64)
330-
tgt = np.array([1, 3, 3, 4], dtype=np.int32)
331-
out = np.take(x32, x64)
332-
assert_equal(out, tgt)
333-
assert_equal(out.dtype, tgt.dtype)
334-
335-
# take 64 32
336-
x64 = np.array([1, 2, 3, 4, 5], dtype=np.int64)
337-
x32 = np.array([0, 2, 2, 3], dtype=np.int32)
338-
tgt = np.array([1, 3, 3, 4], dtype=np.int64)
339-
out = np.take(x64, x32)
340-
assert_equal(out, tgt)
341-
assert_equal(out.dtype, tgt.dtype)
327+
pairs = [(np.int32, np.int32), (np.int32, np.int64), (np.int64, np.int32), (np.int64, np.int64)]
328+
for array_type, indices_type in pairs:
329+
x = np.array([1, 2, 3, 4, 5], dtype=array_type)
330+
ind = np.array([0, 2, 2, 3], dtype=indices_type )
331+
tgt = np.array([1, 3, 3, 4], dtype=array_type)
332+
out = np.take(x, ind)
333+
assert_equal(out, tgt)
334+
assert_equal(out.dtype, tgt.dtype)
342335

343-
# take 32 32
344-
x32_0 = np.array([1, 2, 3, 4, 5], dtype=np.int32)
345-
x32_1 = np.array([0, 2, 2, 3], dtype=np.int32)
346-
tgt = np.array([1, 3, 3, 4], dtype=np.int32)
347-
out = np.take(x32_0, x32_1)
348-
assert_equal(out, tgt)
349-
assert_equal(out.dtype, tgt.dtype)
350-
351-
# take 64 64
352-
x64_0 = np.array([1, 2, 3, 4, 5], dtype=np.int64)
353-
x64_1 = np.array([0, 2, 2, 3], dtype=np.int64)
354-
tgt = np.array([1, 3, 3, 4], dtype=np.int64)
355-
out = np.take(x64_0, x64_1)
356-
assert_equal(out, tgt)
357-
assert_equal(out.dtype, tgt.dtype)
358336

359337
def test_trace(self):
360338
c = [[1, 2], [3, 4], [5, 6]]

0 commit comments

Comments
 (0)