Skip to content

Commit 81ba473

Browse files
Adding basic take, and basic put tests
1 parent ab79d84 commit 81ba473

File tree

1 file changed

+80
-3
lines changed

1 file changed

+80
-3
lines changed

dpctl/tests/test_usm_ndarray_indexing.py

Lines changed: 80 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,13 +17,11 @@
1717

1818
import numpy as np
1919
import pytest
20-
from helper import get_queue_or_skip
20+
from helper import get_queue_or_skip, skip_if_dtype_not_supported
2121

2222
# import dpctl
2323
import dpctl.tensor as dpt
2424

25-
# from helper import skip_if_dtype_not_supported
26-
2725

2826
def test_basic_slice1():
2927
q = get_queue_or_skip()
@@ -435,3 +433,82 @@ def test_integer_strided_indexing():
435433
zc = dpt.copy(z, order="C")
436434
yc = zc[ind0[:2, dpt.newaxis], ind1[dpt.newaxis, -2:]]
437435
assert (dpt.asnumpy(y) == dpt.asnumpy(yc)).all()
436+
437+
438+
@pytest.mark.parametrize(
439+
"data_dt",
440+
["u1", "i1", "u2", "i2", "u4", "i4", "u8", "i8", "e", "f", "d", "F", "D"],
441+
)
442+
@pytest.mark.parametrize(
443+
"ind_dt", ["u1", "i1", "u2", "i2", "u4", "i4", "u8", "i8"]
444+
)
445+
def test_take_basic(data_dt, ind_dt):
446+
q = get_queue_or_skip()
447+
skip_if_dtype_not_supported(data_dt, q)
448+
449+
x = dpt.arange(10, dtype=data_dt)
450+
ind = dpt.arange(2, 5, dtype=ind_dt)
451+
y = dpt.take(x, ind)
452+
assert y.dtype == x.dtype
453+
assert (dpt.asnumpy(y) == np.arange(2, 5, dtype=data_dt)).all()
454+
455+
456+
@pytest.mark.parametrize(
457+
"data_dt",
458+
["u1", "i1", "u2", "i2", "u4", "i4", "u8", "i8", "e", "f", "d", "F", "D"],
459+
)
460+
@pytest.mark.parametrize(
461+
"ind_dt", ["u1", "i1", "u2", "i2", "u4", "i4", "u8", "i8"]
462+
)
463+
def test_put_basic(data_dt, ind_dt):
464+
q = get_queue_or_skip()
465+
skip_if_dtype_not_supported(data_dt, q)
466+
467+
x = dpt.arange(10, dtype=data_dt)
468+
ind = dpt.arange(2, 5, dtype=ind_dt)
469+
val = dpt.ones(3, dtype=data_dt)
470+
dpt.put(x, ind, val)
471+
assert (
472+
dpt.asnumpy(x)
473+
== np.array([0, 1, 1, 1, 1, 5, 6, 7, 8, 9], dtype=data_dt)
474+
).all()
475+
476+
477+
def test_take_basic_axis():
478+
get_queue_or_skip()
479+
480+
n0, n1 = 5, 7
481+
x = dpt.reshape(
482+
dpt.arange(n0 * n1, dtype="i4"),
483+
(
484+
n0,
485+
n1,
486+
),
487+
)
488+
ind = dpt.arange(2, 4)
489+
y0 = dpt.take(x, ind, axis=0)
490+
y1 = dpt.take(x, ind, axis=1)
491+
assert y0.shape == (2, n1)
492+
assert y1.shape == (n0, 2)
493+
494+
495+
def test_put_basic_axis():
496+
get_queue_or_skip()
497+
498+
n0, n1 = 5, 7
499+
x = dpt.reshape(
500+
dpt.arange(n0 * n1, dtype="i4"),
501+
(
502+
n0,
503+
n1,
504+
),
505+
)
506+
ind = dpt.arange(2, 4)
507+
v0 = dpt.zeros((2, n1), dtype=x.dtype)
508+
v1 = dpt.zeros((n0, 2), dtype=x.dtype)
509+
dpt.put(x, ind, v0, axis=0)
510+
dpt.put(x, ind, v1, axis=1)
511+
expected = np.arange(n0 * n1, dtype="i4").reshape((n0, n1))
512+
expected[[2, 3], :] = 0
513+
expected[:, [2, 3]] = 0
514+
assert (expected == dpt.asnumpy(x)).all()

0 commit comments

Comments
 (0)