|
5 | 5 | import numpy as np
|
6 | 6 | import pytest
|
7 | 7 |
|
8 |
| -from .. import ones, arange, reshape, asarray, result_type, all, equal |
| 8 | +from .. import ones, arange, reshape, asarray, result_type, all, equal, stack |
9 | 9 | from .._array_object import Array, CPU_DEVICE, Device
|
10 | 10 | from .._dtypes import (
|
11 | 11 | _all_dtypes,
|
@@ -101,33 +101,37 @@ def test_validate_index():
|
101 | 101 | assert_raises(IndexError, lambda: a[idx])
|
102 | 102 |
|
103 | 103 |
|
104 |
| -def test_indexing_arrays(): |
| 104 | +# @pytest.mark.parametrize("device", ["CPU_DEVICE", "device1", "device2"]) |
| 105 | +def test_indexing_arrays(device='device1'): |
105 | 106 | # indexing with 1D integer arrays and mixes of integers and 1D integer are allowed
|
| 107 | + device = Device(device) |
106 | 108 |
|
107 | 109 | # 1D array
|
108 | 110 | a = arange(5)
|
109 |
| - idx = asarray([1, 0, 1, 2, -1]) |
| 111 | + idx = asarray([1, 0, 1, 2, -1], device=device) |
110 | 112 | a_idx = a[idx]
|
111 | 113 |
|
112 |
| - a_idx_loop = asarray([a[idx[i]] for i in range(idx.shape[0])]) |
| 114 | + a_idx_loop = stack([a[idx[i]] for i in range(idx.shape[0])]) |
113 | 115 | assert all(a_idx == a_idx_loop)
|
| 116 | + assert a_idx.shape == idx.shape |
114 | 117 |
|
115 | 118 | # setitem with arrays is not allowed
|
116 | 119 | with assert_raises(IndexError):
|
117 | 120 | a[idx] = 42
|
118 | 121 |
|
119 | 122 | # mixed array and integer indexing
|
120 |
| - a = reshape(arange(3*4), (3, 4)) |
121 |
| - idx = asarray([1, 0, 1, 2, -1]) |
| 123 | + a = reshape(arange(3*4, device=device), (3, 4)) |
| 124 | + idx = asarray([1, 0, 1, 2, -1], device=device) |
122 | 125 | a_idx = a[idx, 1]
|
123 |
| - |
124 |
| - a_idx_loop = asarray([a[idx[i], 1] for i in range(idx.shape[0])]) |
| 126 | + a_idx_loop = stack([a[idx[i], 1] for i in range(idx.shape[0])]) |
125 | 127 | assert all(a_idx == a_idx_loop)
|
| 128 | + assert a_idx.shape == idx.shape |
126 | 129 |
|
127 | 130 | # index with two arrays
|
128 | 131 | a_idx = a[idx, idx]
|
129 |
| - a_idx_loop = asarray([a[idx[i], idx[i]] for i in range(idx.shape[0])]) |
| 132 | + a_idx_loop = stack([a[idx[i], idx[i]] for i in range(idx.shape[0])]) |
130 | 133 | assert all(a_idx == a_idx_loop)
|
| 134 | + assert a_idx.shape == a_idx.shape |
131 | 135 |
|
132 | 136 | # setitem with arrays is not allowed
|
133 | 137 | with assert_raises(IndexError):
|
|
0 commit comments