Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/api-reference.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
create_diagonal
expand_dims
kron
nunique
setdiff1d
sinc
```
2 changes: 2 additions & 0 deletions src/array_api_extra/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
create_diagonal,
expand_dims,
kron,
nunique,
pad,
setdiff1d,
sinc,
Expand All @@ -23,6 +24,7 @@
"create_diagonal",
"expand_dims",
"kron",
"nunique",
"pad",
"setdiff1d",
"sinc",
Expand Down
40 changes: 40 additions & 0 deletions src/array_api_extra/_funcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
# https://github.com/scikit-learn/scikit-learn/pull/27910#issuecomment-2568023972
from __future__ import annotations

import math
import operator
import warnings
from collections.abc import Callable
Expand All @@ -13,8 +14,10 @@
from ._lib import _compat, _utils
from ._lib._compat import (
array_namespace,
device,
is_jax_array,
is_writeable_array,
size,
)
from ._lib._typing import Array, Index

Expand All @@ -25,6 +28,7 @@
"create_diagonal",
"expand_dims",
"kron",
"nunique",
"pad",
"setdiff1d",
"sinc",
Expand Down Expand Up @@ -638,6 +642,42 @@ def pad(
return at(padded, tuple(slices)).set(x)


def nunique(x: Array, /, *, xp: ModuleType | None = None) -> Array:
    """
    Count the number of unique elements in an array.

    Works with lazy backends such as JAX and Dask, where the number of unique
    values is not available as a concrete shape.

    Parameters
    ----------
    x : Array
        Input array.
    xp : array_namespace, optional
        The standard-compatible namespace for `x`. Default: infer.

    Returns
    -------
    array
        Scalar integer array holding the number of unique elements in `x`.
        It can be lazy.
    """
    xp = array_namespace(x) if xp is None else xp

    if is_jax_array(x):
        # size= is JAX-specific
        # https://github.com/data-apis/array-api/issues/883
        # The result is the number of non-zero entries among the counts.
        jax_counts = xp.unique_counts(x, size=size(x))[1]
        return xp.astype(jax_counts, xp.bool).sum()

    counts = xp.unique_counts(x)[1]
    n_unique = size(counts)

    # FIXME https://github.com/data-apis/array-api-compat/pull/231
    # The size may come back as None or NaN (e.g. Dask, ndonnx); only a real
    # number can be wrapped directly into an array on the device of `x`.
    size_is_known = n_unique is not None and not math.isnan(n_unique)
    if size_is_known:
        return xp.asarray(n_unique, device=device(x))

    # Unknown size: count the non-zero entries of `counts` instead.
    return xp.astype(counts, xp.bool).sum()


class _AtOp(Enum):
"""Operations for use in `xpx.at`."""

Expand Down
19 changes: 19 additions & 0 deletions tests/test_funcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
create_diagonal,
expand_dims,
kron,
nunique,
pad,
setdiff1d,
sinc,
Expand Down Expand Up @@ -448,3 +449,21 @@ def test_list_of_tuples_width(self, xp: ModuleType):

padded = pad(a, [(1, 0), (0, 0)])
assert padded.shape == (4, 4)


class TestNUnique:
    """Tests for `nunique`."""

    def test_simple(self, xp: ModuleType):
        """A 2-D array holding the values 0, 1 and 2 has three unique elements."""
        values = xp.asarray([[1, 1], [0, 2], [2, 2]])
        expected = xp.asarray(3)
        xp_assert_equal(nunique(values), expected)

    def test_empty(self, xp: ModuleType):
        """An empty array has zero unique elements."""
        empty = xp.asarray([])
        expected = xp.asarray(0)
        xp_assert_equal(nunique(empty), expected)

    def test_device(self, xp: ModuleType, device: Device):
        """The result is placed on the same device as the input."""
        scalar = xp.asarray(0.0, device=device)
        result = nunique(scalar)
        assert get_device(result) == device

    def test_xp(self, xp: ModuleType):
        """Passing the namespace explicitly through `xp=` gives the same count."""
        values = xp.asarray([[1, 1], [0, 2], [2, 2]])
        expected = xp.asarray(3)
        xp_assert_equal(nunique(values, xp=xp), expected)
Loading