Skip to content

Commit 8523d8e

Browse files
More tests for advanced indexing
1 parent a966830 commit 8523d8e

File tree

1 file changed

+73
-1
lines changed

1 file changed

+73
-1
lines changed

dpctl/tests/test_usm_ndarray_indexing.py

Lines changed: 73 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
# limitations under the License.
1616

1717

18-
# import numpy as np
18+
import numpy as np
1919
import pytest
2020
from helper import get_queue_or_skip
2121

@@ -174,6 +174,22 @@ def test_advanced_slice1():
174174
)
175175

176176

177+
def test_advanced_slice1_negative_strides():
178+
q = get_queue_or_skip()
179+
ii = dpt.asarray([0, 1], sycl_queue=q)
180+
x = dpt.flip(dpt.arange(5, dtype="i4", sycl_queue=q))
181+
y = x[ii]
182+
assert isinstance(y, dpt.usm_ndarray)
183+
assert y.shape == ii.shape
184+
assert y.strides == (1,)
185+
# FIXME, once usm_ndarray.__equal__ is implemented,
186+
# use of asnumpy should be removed
187+
assert _all_equal(
188+
(x[ii[k]] for k in range(ii.shape[0])),
189+
(y[k] for k in range(ii.shape[0])),
190+
)
191+
192+
177193
def test_advanced_slice2():
178194
q = get_queue_or_skip()
179195
ii = dpt.asarray([1, 2], sycl_queue=q)
@@ -363,3 +379,59 @@ def test_advanced_slice13():
363379
assert isinstance(y, dpt.usm_ndarray)
364380
assert y.shape == expected.shape
365381
assert (dpt.asnumpy(y) == dpt.asnumpy(expected)).all()
382+
383+
384+
def test_integer_indexing_1d():
385+
get_queue_or_skip()
386+
x = dpt.arange(10, dtype="i4")
387+
ind_1d = dpt.asarray([7, 3, 1], dtype="u2")
388+
ind_2d = dpt.asarray([[2, 3, 4], [3, 4, 5], [5, 6, 7]], dtype="i4")
389+
390+
y1 = x[ind_1d]
391+
assert y1.shape == ind_1d.shape
392+
y2 = x[ind_2d]
393+
assert y2.shape == ind_2d.shape
394+
assert (dpt.asnumpy(y1) == np.array([7, 3, 1], dtype="i4")).all()
395+
assert (
396+
dpt.asnumpy(y2)
397+
== np.array([[2, 3, 4], [3, 4, 5], [5, 6, 7]], dtype="i4")
398+
).all()
399+
400+
401+
def test_integer_indexing_2d():
402+
get_queue_or_skip()
403+
n0, n1 = 5, 7
404+
x = dpt.reshape(
405+
dpt.arange(n0 * n1, dtype="i4"),
406+
(
407+
n0,
408+
n1,
409+
),
410+
)
411+
ind0 = dpt.arange(n0)
412+
ind1 = dpt.arange(n1)
413+
414+
y = x[ind0[:2, dpt.newaxis], ind1[dpt.newaxis, -2:]]
415+
assert y.dtype == x.dtype
416+
assert (dpt.asnumpy(y) == np.array([[5, 6], [12, 13]])).all()
417+
418+
419+
def test_integer_strided_indexing():
420+
get_queue_or_skip()
421+
n0, n1 = 5, 7
422+
x = dpt.reshape(
423+
dpt.arange(2 * n0 * n1, dtype="i4"),
424+
(
425+
2 * n0,
426+
n1,
427+
),
428+
)
429+
ind0 = dpt.arange(n0)
430+
ind1 = dpt.arange(n1)
431+
432+
z = x[::-2, :]
433+
y = z[ind0[:2, dpt.newaxis], ind1[dpt.newaxis, -2:]]
434+
assert y.dtype == x.dtype
435+
zc = dpt.copy(z, order="C")
436+
yc = zc[ind0[:2, dpt.newaxis], ind1[dpt.newaxis, -2:]]
437+
assert (dpt.asnumpy(y) == dpt.asnumpy(yc)).all()

0 commit comments

Comments
 (0)