Skip to content

Commit 132728d

Browse files
author
Vahid Tavanashad
committed
add new tests
1 parent 45d8732 commit 132728d

File tree

2 files changed

+57
-21
lines changed

2 files changed

+57
-21
lines changed

tests/test_sycl_queue.py

Lines changed: 28 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -2048,18 +2048,34 @@ def test_concat_stack(func, data1, data2, device):
20482048
assert_sycl_queue_equal(result.sycl_queue, x2.sycl_queue)
20492049

20502050

2051-
@pytest.mark.parametrize(
2052-
"device",
2053-
valid_devices,
2054-
ids=[device.filter_string for device in valid_devices],
2055-
)
2056-
@pytest.mark.parametrize(
2057-
"obj", [slice(None, None, 2), 3, [2, 3]], ids=["slice", "int", "list"]
2058-
)
2059-
def test_delete(device, obj):
2060-
x = dpnp.arange(5, device=device)
2061-
result = dpnp.delete(x, obj)
2062-
assert_sycl_queue_equal(result.sycl_queue, x.sycl_queue)
2051+
class TestDelete:
2052+
@pytest.mark.parametrize(
2053+
"obj",
2054+
[slice(None, None, 2), 3, [2, 3]],
2055+
ids=["slice", "scalar", "list"],
2056+
)
2057+
@pytest.mark.parametrize(
2058+
"device",
2059+
valid_devices,
2060+
ids=[device.filter_string for device in valid_devices],
2061+
)
2062+
def test_delete(self, obj, device):
2063+
x = dpnp.arange(5, device=device)
2064+
result = dpnp.delete(x, obj)
2065+
assert_sycl_queue_equal(result.sycl_queue, x.sycl_queue)
2066+
2067+
@pytest.mark.parametrize(
2068+
"device",
2069+
valid_devices,
2070+
ids=[device.filter_string for device in valid_devices],
2071+
)
2072+
def test_obj_ndarray(self, device):
2073+
x = dpnp.arange(5, device=device)
2074+
y = dpnp.array([1, 4], device=device)
2075+
result = dpnp.delete(x, y)
2076+
2077+
assert_sycl_queue_equal(result.sycl_queue, x.sycl_queue)
2078+
assert_sycl_queue_equal(result.sycl_queue, y.sycl_queue)
20632079

20642080

20652081
@pytest.mark.parametrize(

tests/test_usm_type.py

Lines changed: 29 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -831,16 +831,36 @@ def test_cond(usm_type, p):
831831
assert result.usm_type == usm_type
832832

833833

834-
@pytest.mark.parametrize("usm_type", list_of_usm_types, ids=list_of_usm_types)
835-
@pytest.mark.parametrize(
836-
"obj", [slice(None, None, 2), 3, [2, 3]], ids=["slice", "int", "list"]
837-
)
838-
def test_delete(usm_type, obj):
839-
x = dp.arange(5, usm_type=usm_type)
840-
result = dp.delete(x, obj)
834+
class TestDelete:
835+
@pytest.mark.parametrize(
836+
"obj",
837+
[slice(None, None, 2), 3, [2, 3]],
838+
ids=["slice", "scalar", "list"],
839+
)
840+
@pytest.mark.parametrize(
841+
"usm_type", list_of_usm_types, ids=list_of_usm_types
842+
)
843+
def test_delete(self, obj, usm_type):
844+
x = dp.arange(5, usm_type=usm_type)
845+
result = dp.delete(x, obj)
841846

842-
assert x.usm_type == usm_type
843-
assert result.usm_type == usm_type
847+
assert x.usm_type == usm_type
848+
assert result.usm_type == usm_type
849+
850+
@pytest.mark.parametrize(
851+
"usm_type_x", list_of_usm_types, ids=list_of_usm_types
852+
)
853+
@pytest.mark.parametrize(
854+
"usm_type_y", list_of_usm_types, ids=list_of_usm_types
855+
)
856+
def test_obj_ndarray(self, usm_type_x, usm_type_y):
857+
x = dp.arange(5, usm_type=usm_type_x)
858+
y = dp.array([1, 4], usm_type=usm_type_y)
859+
z = dp.delete(x, y)
860+
861+
assert x.usm_type == usm_type_x
862+
assert y.usm_type == usm_type_y
863+
assert z.usm_type == du.get_coerced_usm_type([usm_type_x, usm_type_y])
844864

845865

846866
@pytest.mark.parametrize("usm_type", list_of_usm_types, ids=list_of_usm_types)

0 commit comments

Comments
 (0)