Skip to content

Commit e3432d7

Browse files
committed
Implements dpctl.tensor.count_nonzero
1 parent 442ba6f commit e3432d7

File tree

2 files changed

+14
-0
lines changed

2 files changed

+14
-0
lines changed

dpctl/tensor/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -176,6 +176,7 @@
176176
from ._reduction import (
177177
argmax,
178178
argmin,
179+
count_nonzero,
179180
logsumexp,
180181
max,
181182
min,
@@ -374,4 +375,5 @@
374375
"cumulative_sum",
375376
"nextafter",
376377
"diff",
378+
"count_nonzero",
377379
]

dpctl/tensor/_reduction.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -773,3 +773,15 @@ def argmin(x, /, *, axis=None, keepdims=False, out=None):
773773
default array index data type for the device of ``x``.
774774
"""
775775
return _search_over_axis(x, axis, keepdims, out, tri._argmin_over_axis)
776+
777+
778+
def count_nonzero(x, /, *, axis=None, keepdims=False, out=None):
779+
if x.dtype != dpt.bool:
780+
x = dpt.astype(x, dpt.bool, copy=False)
781+
return sum(
782+
x,
783+
axis=axis,
784+
dtype=ti.default_device_index_type(x.sycl_device),
785+
keepdims=keepdims,
786+
out=None,
787+
)

0 commit comments

Comments
 (0)