Skip to content

Commit 46f8bb6

Browse files
committed
Use _compat ns
1 parent 37ce670 commit 46f8bb6

File tree

1 file changed

+3
-5
lines changed

1 file changed

+3
-5
lines changed

src/array_api_extra/_funcs.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -14,10 +14,8 @@
1414
from ._lib import _compat, _utils
1515
from ._lib._compat import (
1616
array_namespace,
17-
device,
1817
is_jax_array,
1918
is_writeable_array,
20-
size,
2119
)
2220
from ._lib._typing import Array, Index
2321

@@ -667,15 +665,15 @@ def nunique(x: Array, /, *, xp: ModuleType | None = None) -> Array:
667665
if is_jax_array(x):
668666
# size= is JAX-specific
669667
# https://github.com/data-apis/array-api/issues/883
670-
_, counts = xp.unique_counts(x, size=size(x))
668+
_, counts = xp.unique_counts(x, size=_compat.size(x))
671669
return xp.astype(counts, xp.bool).sum()
672670

673671
_, counts = xp.unique_counts(x)
674-
n = size(counts)
672+
n = _compat.size(counts)
675673
# FIXME https://github.com/data-apis/array-api-compat/pull/231
676674
if n is None or math.isnan(n): # e.g. Dask, ndonnx
677675
return xp.astype(counts, xp.bool).sum()
678-
return xp.asarray(n, device=device(x))
676+
return xp.asarray(n, device=_compat.device(x))
679677

680678

681679
class _AtOp(Enum):

0 commit comments

Comments
 (0)