Skip to content

Commit 61917a5

Browse files
More tests for advanced indexing
1 parent d66c494 commit 61917a5

File tree

1 file changed

+209
-3
lines changed

1 file changed

+209
-3
lines changed

dpctl/tests/test_usm_ndarray_indexing.py

Lines changed: 209 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616

1717

1818
# import numpy as np
19-
# import pytest
19+
import pytest
2020
from helper import get_queue_or_skip
2121

2222
# import dpctl
@@ -144,6 +144,10 @@ def test_basic_slice10():
144144
assert y.strides == (0, n1 * n2, n2, 1)
145145

146146

147+
def _all_equal(it1, it2):
148+
return all(dpt.asnumpy(x) == dpt.asnumpy(y) for x, y in zip(it1, it2))
149+
150+
147151
def test_advanced_slice1():
148152
q = get_queue_or_skip()
149153
ii = dpt.asarray([1, 2], sycl_queue=q)
@@ -154,6 +158,208 @@ def test_advanced_slice1():
154158
assert y.strides == (1,)
155159
# FIXME, once usm_ndarray.__equal__ is implemented,
156160
# use of asnumpy should be removed
157-
assert all(
158-
dpt.asnumpy(x[ii[k]]) == dpt.asnumpy(y[k]) for k in range(ii.shape[0])
161+
assert _all_equal(
162+
(x[ii[k]] for k in range(ii.shape[0])),
163+
(y[k] for k in range(ii.shape[0])),
164+
)
165+
y = x[(ii,)]
166+
assert isinstance(y, dpt.usm_ndarray)
167+
assert y.shape == ii.shape
168+
assert y.strides == (1,)
169+
# FIXME, once usm_ndarray.__equal__ is implemented,
170+
# use of asnumpy should be removed
171+
assert _all_equal(
172+
(x[ii[k]] for k in range(ii.shape[0])),
173+
(y[k] for k in range(ii.shape[0])),
159174
)
175+
176+
177+
def test_advanced_slice2():
178+
q = get_queue_or_skip()
179+
ii = dpt.asarray([1, 2], sycl_queue=q)
180+
x = dpt.arange(10, dtype="i4", sycl_queue=q)
181+
y = x[ii, dpt.newaxis]
182+
assert isinstance(y, dpt.usm_ndarray)
183+
assert y.shape == ii.shape + (1,)
184+
assert y.flags["C"]
185+
186+
187+
def test_advanced_slice3():
188+
q = get_queue_or_skip()
189+
ii = dpt.asarray([1, 2], sycl_queue=q)
190+
x = dpt.arange(10, dtype="i4", sycl_queue=q)
191+
y = x[dpt.newaxis, ii]
192+
assert isinstance(y, dpt.usm_ndarray)
193+
assert y.shape == (1,) + ii.shape
194+
assert y.flags["C"]
195+
196+
197+
def _make_3d(dt, q):
198+
return dpt.reshape(
199+
dpt.arange(3 * 3 * 3, dtype=dt, sycl_queue=q),
200+
(
201+
3,
202+
3,
203+
3,
204+
),
205+
)
206+
207+
208+
def test_advanced_slice4():
209+
q = get_queue_or_skip()
210+
ii = dpt.asarray([1, 2], sycl_queue=q)
211+
x = _make_3d("i4", q)
212+
y = x[ii, ii, ii]
213+
assert isinstance(y, dpt.usm_ndarray)
214+
assert y.shape == ii.shape
215+
assert _all_equal(
216+
(x[ii[k], ii[k], ii[k]] for k in range(ii.shape[0])),
217+
(y[k] for k in range(ii.shape[0])),
218+
)
219+
220+
221+
def test_advanced_slice5():
222+
q = get_queue_or_skip()
223+
ii = dpt.asarray([1, 2], sycl_queue=q)
224+
x = _make_3d("i4", q)
225+
with pytest.raises(IndexError):
226+
x[ii, 0, ii]
227+
228+
229+
def test_advanced_slice6():
230+
q = get_queue_or_skip()
231+
ii = dpt.asarray([1, 2], sycl_queue=q)
232+
x = _make_3d("i4", q)
233+
y = x[:, ii, ii]
234+
assert isinstance(y, dpt.usm_ndarray)
235+
assert y.shape == (
236+
x.shape[0],
237+
ii.shape[0],
238+
)
239+
assert _all_equal(
240+
(
241+
x[i, ii[k], ii[k]]
242+
for i in range(x.shape[0])
243+
for k in range(ii.shape[0])
244+
),
245+
(y[i, k] for i in range(x.shape[0]) for k in range(ii.shape[0])),
246+
)
247+
248+
249+
def test_advanced_slice7():
250+
q = get_queue_or_skip()
251+
mask = dpt.asarray(
252+
[
253+
[[True, True, False], [False, True, True], [True, False, True]],
254+
[[True, False, False], [False, False, True], [False, True, False]],
255+
[[True, True, True], [False, False, False], [False, False, True]],
256+
],
257+
sycl_queue=q,
258+
)
259+
x = _make_3d("i2", q)
260+
y = x[mask]
261+
expected = [0, 1, 4, 5, 6, 8, 9, 14, 16, 18, 19, 20, 26]
262+
assert isinstance(y, dpt.usm_ndarray)
263+
assert y.shape == (len(expected),)
264+
assert all(dpt.asnumpy(y[k]) == expected[k] for k in range(len(expected)))
265+
266+
267+
def test_advanced_slice8():
268+
q = get_queue_or_skip()
269+
mask = dpt.asarray(
270+
[[True, False, False], [False, True, False], [False, True, False]],
271+
sycl_queue=q,
272+
)
273+
x = _make_3d("u2", q)
274+
y = x[mask]
275+
expected = dpt.asarray(
276+
[[0, 1, 2], [12, 13, 14], [21, 22, 23]], sycl_queue=q
277+
)
278+
assert isinstance(y, dpt.usm_ndarray)
279+
assert y.shape == expected.shape
280+
assert (dpt.asnumpy(y) == dpt.asnumpy(expected)).all()
281+
282+
283+
def test_advanced_slice9():
284+
q = get_queue_or_skip()
285+
mask = dpt.asarray(
286+
[[True, False, False], [False, True, False], [False, True, False]],
287+
sycl_queue=q,
288+
)
289+
x = _make_3d("u4", q)
290+
y = x[:, mask]
291+
expected = dpt.asarray([[0, 4, 7], [9, 13, 16], [18, 22, 25]], sycl_queue=q)
292+
assert isinstance(y, dpt.usm_ndarray)
293+
assert y.shape == expected.shape
294+
assert (dpt.asnumpy(y) == dpt.asnumpy(expected)).all()
295+
296+
297+
def lin_id(i, j, k):
298+
"""global_linear_id for (3,3,3) range traversed in C-contiguous order"""
299+
return 9 * i + 3 * j + k
300+
301+
302+
def test_advanced_slice10():
303+
q = get_queue_or_skip()
304+
x = _make_3d("u8", q)
305+
i0 = dpt.asarray([0, 1, 1], device=x.device)
306+
i1 = dpt.asarray([1, 1, 2], device=x.device)
307+
i2 = dpt.asarray([2, 0, 1], device=x.device)
308+
y = x[i0, i1, i2]
309+
res_expected = dpt.asarray(
310+
[
311+
lin_id(0, 1, 2),
312+
lin_id(1, 1, 0),
313+
lin_id(1, 2, 1),
314+
],
315+
sycl_queue=q,
316+
)
317+
assert isinstance(y, dpt.usm_ndarray)
318+
assert y.shape == res_expected.shape
319+
assert (dpt.asnumpy(y) == dpt.asnumpy(res_expected)).all()
320+
321+
322+
def test_advanced_slice11():
323+
q = get_queue_or_skip()
324+
x = _make_3d("u8", q)
325+
i0 = dpt.asarray([0, 1, 1], device=x.device)
326+
i2 = dpt.asarray([2, 0, 1], device=x.device)
327+
with pytest.raises(IndexError):
328+
x[i0, :, i2]
329+
330+
331+
def test_advanced_slice12():
332+
q = get_queue_or_skip()
333+
x = _make_3d("u8", q)
334+
i1 = dpt.asarray([1, 1, 2], device=x.device)
335+
i2 = dpt.asarray([2, 0, 1], device=x.device)
336+
y = x[:, dpt.newaxis, i1, i2, dpt.newaxis]
337+
res_expected = dpt.asarray(
338+
[
339+
[[[lin_id(0, 1, 2)], [lin_id(0, 1, 0)], [lin_id(0, 2, 1)]]],
340+
[[[lin_id(1, 1, 2)], [lin_id(1, 1, 0)], [lin_id(1, 2, 1)]]],
341+
[[[lin_id(2, 1, 2)], [lin_id(2, 1, 0)], [lin_id(2, 2, 1)]]],
342+
],
343+
sycl_queue=q,
344+
)
345+
assert isinstance(y, dpt.usm_ndarray)
346+
assert y.shape == res_expected.shape
347+
assert (dpt.asnumpy(y) == dpt.asnumpy(res_expected)).all()
348+
349+
350+
def test_advanced_slice13():
351+
q = get_queue_or_skip()
352+
x = _make_3d("u8", q)
353+
i1 = dpt.asarray([[1], [2]], device=x.device)
354+
i2 = dpt.asarray([[0, 1]], device=x.device)
355+
y = x[i1, i2, 0]
356+
expected = dpt.asarray(
357+
[
358+
[lin_id(1, 0, 0), lin_id(1, 1, 0)],
359+
[lin_id(2, 0, 0), lin_id(2, 1, 0)],
360+
],
361+
device=x.device,
362+
)
363+
assert isinstance(y, dpt.usm_ndarray)
364+
assert y.shape == expected.shape
365+
assert (dpt.asnumpy(y) == dpt.asnumpy(expected)).all()

0 commit comments

Comments
 (0)