Skip to content

Commit b4264a8

Browse files
committed
Use NumPy in compress tests
1 parent 9a4430c commit b4264a8

File tree

1 file changed

+13
-7
lines changed

1 file changed

+13
-7
lines changed

dpnp/tests/test_indexing.py

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1339,18 +1339,24 @@ def test_error(self):
13391339

13401340
class TestCompress:
13411341
def test_compress_basic(self):
1342+
conditions = [True, False, True]
1343+
a_np = numpy.arange(16).reshape(4, 4)
13421344
a = dpnp.arange(16).reshape(4, 4)
1343-
condition = dpnp.asarray([True, False, True])
1344-
r = dpnp.compress(condition, a, axis=0)
1345-
assert_array_equal(r[0], a[0])
1346-
assert_array_equal(r[1], a[2])
1345+
cond_np = numpy.array(conditions)
1346+
cond = dpnp.array(conditions)
1347+
expected = numpy.compress(cond_np, a_np, axis=0)
1348+
result = dpnp.compress(cond, a, axis=0)
1349+
assert_array_equal(expected, result)
13471350

13481351
@pytest.mark.parametrize("dtype", get_all_dtypes())
13491352
def test_compress_condition_all_dtypes(self, dtype):
1353+
a_np = numpy.arange(10, dtype="i4")
13501354
a = dpnp.arange(10, dtype="i4")
1351-
condition = dpnp.tile(dpnp.asarray([0, 1], dtype=dtype), 5)
1352-
r = dpnp.compress(condition, a)
1353-
assert_array_equal(r, a[1::2])
1355+
cond_np = numpy.tile(numpy.asarray([0, 1], dtype=dtype), 5)
1356+
cond = dpnp.tile(dpnp.asarray([0, 1], dtype=dtype), 5)
1357+
expected = numpy.compress(cond_np, a_np)
1358+
result = dpnp.compress(cond, a)
1359+
assert_array_equal(expected, result)
13541360

13551361
def test_compress_invalid_out_errors(self):
13561362
q1 = dpctl.SyclQueue()

0 commit comments

Comments
 (0)