Skip to content

Commit 234d81e

Browse files
committed
Adds a basic test for count_nonzero
1 parent 9fb94d5 commit 234d81e

File tree

1 file changed

+19
-0
lines changed

1 file changed

+19
-0
lines changed

dpctl/tests/test_usm_ndarray_reductions.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
from numpy.testing import assert_allclose
2222

2323
import dpctl.tensor as dpt
24+
from dpctl.tensor._tensor_impl import default_device_index_type
2425
from dpctl.tests.helper import get_queue_or_skip, skip_if_dtype_not_supported
2526
from dpctl.utils import ExecutionPlacementError
2627

@@ -669,3 +670,21 @@ def test_reduction_out_kwarg_arg_validation():
669670
keepdims=True,
670671
out=dpt.empty_like(out_wrong_keepdims, dtype=ind_dt),
671672
)
673+
674+
675+
@pytest.mark.parametrize("dt", _all_dtypes)
676+
def test_count_nonzero(dt):
677+
q = get_queue_or_skip()
678+
skip_if_dtype_not_supported(dt, q)
679+
680+
expected_dt = default_device_index_type(q.sycl_device)
681+
682+
x = dpt.ones(10, dtype=dt, sycl_queue=q)
683+
res = dpt.count_nonzero(x)
684+
assert res == 10
685+
assert x.dtype == expected_dt
686+
687+
x[3:6] = 0
688+
res = dpt.count_nonzero(x)
689+
assert res == 7
690+
assert x.dtype == expected_dt

0 commit comments

Comments
 (0)