|
1 | 1 | import functools |
2 | 2 |
|
| 3 | +import dpctl |
3 | 4 | import dpctl.tensor as dpt |
4 | 5 | import numpy |
5 | 6 | import pytest |
6 | 7 | from dpctl.tensor._numpy_helper import AxisError |
| 8 | +from dpctl.utils import ExecutionPlacementError |
7 | 9 | from numpy.testing import ( |
8 | 10 | assert_, |
9 | 11 | assert_array_equal, |
@@ -1333,3 +1335,51 @@ def test_error(self): |
1333 | 1335 | dpnp.select([x0], [x1], default=x1) |
1334 | 1336 | with pytest.raises(TypeError): |
1335 | 1337 | dpnp.select([x1], [x1]) |
| 1338 | + |
| 1339 | + |
| 1340 | +def test_compress_basic(): |
| 1341 | + a = dpnp.arange(16).reshape(4, 4) |
| 1342 | + condition = dpnp.asarray([True, False, True]) |
| 1343 | + r = dpnp.compress(condition, a, axis=0) |
| 1344 | + assert_array_equal(r[0], a[0]) |
| 1345 | + assert_array_equal(r[1], a[2]) |
| 1346 | + |
| 1347 | + |
| 1348 | +@pytest.mark.parametrize("dtype", get_all_dtypes()) |
| 1349 | +def test_compress_condition_all_dtypes(dtype): |
| 1350 | + a = dpnp.arange(10, dtype="i4") |
| 1351 | + condition = dpnp.tile(dpnp.asarray([0, 1], dtype=dtype), 5) |
| 1352 | + r = dpnp.compress(condition, a) |
| 1353 | + assert_array_equal(r, a[1::2]) |
| 1354 | + |
| 1355 | + |
| 1356 | +def test_compress_invalid_out_errors(): |
| 1357 | + q1 = dpctl.SyclQueue() |
| 1358 | + q2 = dpctl.SyclQueue() |
| 1359 | + a = dpnp.ones(10, dtype="i4", sycl_queue=q1) |
| 1360 | + condition = dpnp.asarray([True], sycl_queue=q1) |
| 1361 | + out_bad_shape = dpnp.empty_like(a) |
| 1362 | + with pytest.raises(ValueError): |
| 1363 | + dpnp.compress(condition, a, out=out_bad_shape) |
| 1364 | + out_bad_queue = dpnp.empty(1, dtype="i4", sycl_queue=q2) |
| 1365 | + with pytest.raises(ExecutionPlacementError): |
| 1366 | + dpnp.compress(condition, a, out=out_bad_queue) |
| 1367 | + out_bad_dt = dpnp.empty(1, dtype="i8", sycl_queue=q1) |
| 1368 | + with pytest.raises(ValueError): |
| 1369 | + dpnp.compress(condition, a, out=out_bad_dt) |
| 1370 | + out_read_only = dpnp.empty(1, dtype="i4", sycl_queue=q1) |
| 1371 | + out_read_only.flags.writable = False |
| 1372 | + with pytest.raises(ValueError): |
| 1373 | + dpnp.compress(condition, a, out=out_read_only) |
| 1374 | + |
| 1375 | + |
| 1376 | +def test_compress_empty_axis(): |
| 1377 | + a = dpnp.ones((10, 0, 5), dtype="i4") |
| 1378 | + condition = [True, False, True] |
| 1379 | + r = dpnp.compress(condition, a, axis=0) |
| 1380 | + assert r.shape == (2, 0, 5) |
| 1381 | + # empty take from empty axis is permitted |
| 1382 | + assert dpnp.compress([False], a, axis=1).shape == (10, 0, 5) |
| 1383 | + # non-empty take from empty axis raises IndexError |
| 1384 | + with pytest.raises(IndexError): |
| 1385 | + dpnp.compress(condition, a, axis=1) |
0 commit comments