Skip to content

Commit b90fc90

Browse files
committed
add basic tests for indexing with NumPy arrays
1 parent e98f0d6 commit b90fc90

File tree

1 file changed

+22
-0
lines changed

1 file changed

+22
-0
lines changed

dpctl/tests/test_usm_ndarray_indexing.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -440,6 +440,28 @@ def test_advanced_slice16():
440440
assert isinstance(y, dpt.usm_ndarray)
441441

442442

443+
def test_integer_indexing_numpy_array():
444+
q = get_queue_or_skip()
445+
ii = np.asarray([1, 2])
446+
x = dpt.arange(10, dtype="i4", sycl_queue=q)
447+
y = x[ii]
448+
assert isinstance(y, dpt.usm_ndarray)
449+
assert y.shape == ii.shape
450+
assert dpt.all(dpt.asarray(ii, sycl_queue=q) == y)
451+
452+
453+
def test_boolean_indexing_numpy_array():
454+
q = get_queue_or_skip()
455+
ii = np.asarray(
456+
[False, True, True, False, False, False, False, False, False, False]
457+
)
458+
x = dpt.arange(10, dtype="i4", sycl_queue=q)
459+
y = x[ii]
460+
assert isinstance(y, dpt.usm_ndarray)
461+
assert y.shape == (2,)
462+
assert dpt.all(x[1:3] == y)
463+
464+
443465
def test_boolean_indexing_validation():
444466
get_queue_or_skip()
445467
x = dpt.zeros(10, dtype="i4")

0 commit comments

Comments
 (0)