Skip to content

Commit 49bfa0e

Browse files
committed
Add tests for compress
1 parent 2381662 commit 49bfa0e

File tree

1 file changed

+50
-0
lines changed

1 file changed

+50
-0
lines changed

dpnp/tests/test_indexing.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
11
import functools
22

3+
import dpctl
34
import dpctl.tensor as dpt
45
import numpy
56
import pytest
67
from dpctl.tensor._numpy_helper import AxisError
8+
from dpctl.utils import ExecutionPlacementError
79
from numpy.testing import (
810
assert_,
911
assert_array_equal,
@@ -1333,3 +1335,51 @@ def test_error(self):
13331335
dpnp.select([x0], [x1], default=x1)
13341336
with pytest.raises(TypeError):
13351337
dpnp.select([x1], [x1])
1338+
1339+
1340+
def test_compress_basic():
1341+
a = dpnp.arange(16).reshape(4, 4)
1342+
condition = dpnp.asarray([True, False, True])
1343+
r = dpnp.compress(condition, a, axis=0)
1344+
assert_array_equal(r[0], a[0])
1345+
assert_array_equal(r[1], a[2])
1346+
1347+
1348+
@pytest.mark.parametrize("dtype", get_all_dtypes())
1349+
def test_compress_condition_all_dtypes(dtype):
1350+
a = dpnp.arange(10, dtype="i4")
1351+
condition = dpnp.tile(dpnp.asarray([0, 1], dtype=dtype), 5)
1352+
r = dpnp.compress(condition, a)
1353+
assert_array_equal(r, a[1::2])
1354+
1355+
1356+
def test_compress_invalid_out_errors():
1357+
q1 = dpctl.SyclQueue()
1358+
q2 = dpctl.SyclQueue()
1359+
a = dpnp.ones(10, dtype="i4", sycl_queue=q1)
1360+
condition = dpnp.asarray([True], sycl_queue=q1)
1361+
out_bad_shape = dpnp.empty_like(a)
1362+
with pytest.raises(ValueError):
1363+
dpnp.compress(condition, a, out=out_bad_shape)
1364+
out_bad_queue = dpnp.empty(1, dtype="i4", sycl_queue=q2)
1365+
with pytest.raises(ExecutionPlacementError):
1366+
dpnp.compress(condition, a, out=out_bad_queue)
1367+
out_bad_dt = dpnp.empty(1, dtype="i8", sycl_queue=q1)
1368+
with pytest.raises(ValueError):
1369+
dpnp.compress(condition, a, out=out_bad_dt)
1370+
out_read_only = dpnp.empty(1, dtype="i4", sycl_queue=q1)
1371+
out_read_only.flags.writable = False
1372+
with pytest.raises(ValueError):
1373+
dpnp.compress(condition, a, out=out_read_only)
1374+
1375+
1376+
def test_compress_empty_axis():
1377+
a = dpnp.ones((10, 0, 5), dtype="i4")
1378+
condition = [True, False, True]
1379+
r = dpnp.compress(condition, a, axis=0)
1380+
assert r.shape == (2, 0, 5)
1381+
# empty take from empty axis is permitted
1382+
assert dpnp.compress([False], a, axis=1).shape == (10, 0, 5)
1383+
# non-empty take from empty axis raises IndexError
1384+
with pytest.raises(IndexError):
1385+
dpnp.compress(condition, a, axis=1)

0 commit comments

Comments
 (0)