|
3 | 3 | # https://github.com/scikit-learn/scikit-learn/pull/27910#issuecomment-2568023972 |
4 | 4 | from __future__ import annotations |
5 | 5 |
|
| 6 | +import math |
6 | 7 | import operator |
7 | 8 | import warnings |
8 | 9 | from collections.abc import Callable |
|
13 | 14 | from ._lib import _compat, _utils |
14 | 15 | from ._lib._compat import ( |
15 | 16 | array_namespace, |
| 17 | + device, |
16 | 18 | is_jax_array, |
17 | 19 | is_writeable_array, |
| 20 | + size, |
18 | 21 | ) |
19 | 22 | from ._lib._typing import Array, Index |
20 | 23 |
|
|
25 | 28 | "create_diagonal", |
26 | 29 | "expand_dims", |
27 | 30 | "kron", |
| 31 | + "nunique", |
28 | 32 | "pad", |
29 | 33 | "setdiff1d", |
30 | 34 | "sinc", |
@@ -638,6 +642,42 @@ def pad( |
638 | 642 | return at(padded, tuple(slices)).set(x) |
639 | 643 |
|
640 | 644 |
|
def nunique(x: Array, /, *, xp: ModuleType | None = None) -> Array:
    """
    Count the number of unique elements in an array.

    Compatible with JAX and Dask, whose laziness would be otherwise
    problematic.

    Parameters
    ----------
    x : Array
        Input array.
    xp : array_namespace, optional
        The standard-compatible namespace for `x`. Default: infer.

    Returns
    -------
    Array
        Scalar integer array holding the number of unique elements in `x`.
        It can be lazy.
    """
    if xp is None:
        xp = array_namespace(x)

    if is_jax_array(x):
        # size= is JAX-specific
        # https://github.com/data-apis/array-api/issues/883
        # It fixes the output length to the size of `x`; slots that do not
        # correspond to an actual unique value are expected to carry a zero
        # count, so the answer is the number of non-zero counts.
        _, counts = xp.unique_counts(x, size=size(x))
        # Reduce through the namespace function rather than an array method:
        # the array API standard only specifies `sum` as a function.
        return xp.sum(xp.astype(counts, xp.bool))

    _, counts = xp.unique_counts(x)
    n = size(counts)
    # FIXME https://github.com/data-apis/array-api-compat/pull/231
    if n is None or math.isnan(n):  # e.g. Dask, ndonnx
        # The number of elements of `counts` is not known at this point, so
        # it cannot be returned directly; count the entries with a reduction.
        return xp.sum(xp.astype(counts, xp.bool))

    # The number of elements of `counts` is known: wrap it in a 0-d array
    # placed on the same device as the input.
    return xp.asarray(n, device=device(x))
| 679 | + |
| 680 | + |
641 | 681 | class _AtOp(Enum): |
642 | 682 | """Operations for use in `xpx.at`.""" |
643 | 683 |
|
|
0 commit comments