Skip to content

Commit 5524b7f

Browse files
committed
Implements dpctl.tensor.count_nonzero
1 parent f975dbb commit 5524b7f

File tree

2 files changed

+14
-0
lines changed

2 files changed

+14
-0
lines changed

dpctl/tensor/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -175,6 +175,7 @@
175175
from ._reduction import (
176176
argmax,
177177
argmin,
178+
count_nonzero,
178179
logsumexp,
179180
max,
180181
min,
@@ -372,4 +373,5 @@
372373
"cumulative_prod",
373374
"cumulative_sum",
374375
"diff",
376+
"count_nonzero",
375377
]

dpctl/tensor/_reduction.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -772,3 +772,15 @@ def argmin(x, /, *, axis=None, keepdims=False, out=None):
772772
default array index data type for the device of ``x``.
773773
"""
774774
return _search_over_axis(x, axis, keepdims, out, tri._argmin_over_axis)
775+
776+
777+
def count_nonzero(x, /, *, axis=None, keepdims=False, out=None):
778+
if x.dtype != dpt.bool:
779+
x = dpt.astype(x, dpt.bool, copy=False)
780+
return sum(
781+
x,
782+
axis=axis,
783+
dtype=ti.default_device_index_type(x.sycl_device),
784+
keepdims=keepdims,
785+
out=None,
786+
)

0 commit comments

Comments
 (0)