Skip to content

Commit c39dcf4

Browse files
authored
Make count_nonzero work with array-api-strict (#806)
1 parent c33ee09 commit c39dcf4

File tree

1 file changed

+2
-1
lines changed

1 file changed

+2
-1
lines changed

cubed/array_api/searching_functions.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,8 +41,9 @@ def argmin(x, /, *, axis=None, keepdims=False, split_every=None):
4141

4242
def count_nonzero(x, /, *, axis=None, keepdims=False, split_every=None):
4343
dtype = nxp.__array_namespace_info__().default_dtypes(device=x.device)["indexing"]
44+
x_nonzero = astype(x, nxp.bool)
4445
return sum(
45-
astype(x, nxp.bool),
46+
astype(x_nonzero, dtype),
4647
axis=axis,
4748
dtype=dtype,
4849
keepdims=keepdims,

0 commit comments

Comments
 (0)