Skip to content

Commit 9a4430c

Browse files
committed
Move compress tests into a TestCompress class
1 parent 5818d10 commit 9a4430c

File tree

1 file changed

+44
-46
lines changed

1 file changed

+44
-46
lines changed

dpnp/tests/test_indexing.py

Lines changed: 44 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -1337,49 +1337,47 @@ def test_error(self):
13371337
dpnp.select([x1], [x1])
13381338

13391339

1340-
def test_compress_basic():
1341-
a = dpnp.arange(16).reshape(4, 4)
1342-
condition = dpnp.asarray([True, False, True])
1343-
r = dpnp.compress(condition, a, axis=0)
1344-
assert_array_equal(r[0], a[0])
1345-
assert_array_equal(r[1], a[2])
1346-
1347-
1348-
@pytest.mark.parametrize("dtype", get_all_dtypes())
1349-
def test_compress_condition_all_dtypes(dtype):
1350-
a = dpnp.arange(10, dtype="i4")
1351-
condition = dpnp.tile(dpnp.asarray([0, 1], dtype=dtype), 5)
1352-
r = dpnp.compress(condition, a)
1353-
assert_array_equal(r, a[1::2])
1354-
1355-
1356-
def test_compress_invalid_out_errors():
1357-
q1 = dpctl.SyclQueue()
1358-
q2 = dpctl.SyclQueue()
1359-
a = dpnp.ones(10, dtype="i4", sycl_queue=q1)
1360-
condition = dpnp.asarray([True], sycl_queue=q1)
1361-
out_bad_shape = dpnp.empty_like(a)
1362-
with pytest.raises(ValueError):
1363-
dpnp.compress(condition, a, out=out_bad_shape)
1364-
out_bad_queue = dpnp.empty(1, dtype="i4", sycl_queue=q2)
1365-
with pytest.raises(ExecutionPlacementError):
1366-
dpnp.compress(condition, a, out=out_bad_queue)
1367-
out_bad_dt = dpnp.empty(1, dtype="i8", sycl_queue=q1)
1368-
with pytest.raises(TypeError):
1369-
dpnp.compress(condition, a, out=out_bad_dt)
1370-
out_read_only = dpnp.empty(1, dtype="i4", sycl_queue=q1)
1371-
out_read_only.flags.writable = False
1372-
with pytest.raises(ValueError):
1373-
dpnp.compress(condition, a, out=out_read_only)
1374-
1375-
1376-
def test_compress_empty_axis():
1377-
a = dpnp.ones((10, 0, 5), dtype="i4")
1378-
condition = [True, False, True]
1379-
r = dpnp.compress(condition, a, axis=0)
1380-
assert r.shape == (2, 0, 5)
1381-
# empty take from empty axis is permitted
1382-
assert dpnp.compress([False], a, axis=1).shape == (10, 0, 5)
1383-
# non-empty take from empty axis raises IndexError
1384-
with pytest.raises(IndexError):
1385-
dpnp.compress(condition, a, axis=1)
1340+
class TestCompress:
1341+
def test_compress_basic(self):
1342+
a = dpnp.arange(16).reshape(4, 4)
1343+
condition = dpnp.asarray([True, False, True])
1344+
r = dpnp.compress(condition, a, axis=0)
1345+
assert_array_equal(r[0], a[0])
1346+
assert_array_equal(r[1], a[2])
1347+
1348+
@pytest.mark.parametrize("dtype", get_all_dtypes())
1349+
def test_compress_condition_all_dtypes(self, dtype):
1350+
a = dpnp.arange(10, dtype="i4")
1351+
condition = dpnp.tile(dpnp.asarray([0, 1], dtype=dtype), 5)
1352+
r = dpnp.compress(condition, a)
1353+
assert_array_equal(r, a[1::2])
1354+
1355+
def test_compress_invalid_out_errors(self):
1356+
q1 = dpctl.SyclQueue()
1357+
q2 = dpctl.SyclQueue()
1358+
a = dpnp.ones(10, dtype="i4", sycl_queue=q1)
1359+
condition = dpnp.asarray([True], sycl_queue=q1)
1360+
out_bad_shape = dpnp.empty_like(a)
1361+
with pytest.raises(ValueError):
1362+
dpnp.compress(condition, a, out=out_bad_shape)
1363+
out_bad_queue = dpnp.empty(1, dtype="i4", sycl_queue=q2)
1364+
with pytest.raises(ExecutionPlacementError):
1365+
dpnp.compress(condition, a, out=out_bad_queue)
1366+
out_bad_dt = dpnp.empty(1, dtype="i8", sycl_queue=q1)
1367+
with pytest.raises(TypeError):
1368+
dpnp.compress(condition, a, out=out_bad_dt)
1369+
out_read_only = dpnp.empty(1, dtype="i4", sycl_queue=q1)
1370+
out_read_only.flags.writable = False
1371+
with pytest.raises(ValueError):
1372+
dpnp.compress(condition, a, out=out_read_only)
1373+
1374+
def test_compress_empty_axis(self):
1375+
a = dpnp.ones((10, 0, 5), dtype="i4")
1376+
condition = [True, False, True]
1377+
r = dpnp.compress(condition, a, axis=0)
1378+
assert r.shape == (2, 0, 5)
1379+
# empty take from empty axis is permitted
1380+
assert dpnp.compress([False], a, axis=1).shape == (10, 0, 5)
1381+
# non-empty take from empty axis raises IndexError
1382+
with pytest.raises(IndexError):
1383+
dpnp.compress(condition, a, axis=1)

0 commit comments

Comments
 (0)