Skip to content

Commit 6c5ef1a

Browse files
committed
Update jnp.unique to support upstream interface changes.
1 parent 96dce0b commit 6c5ef1a

File tree

2 files changed

+10
-3
lines changed

2 files changed

+10
-3
lines changed

jax/_src/numpy/setops.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -663,7 +663,8 @@ def _unique(ar: Array, axis: int, return_index: bool = False, return_inverse: bo
663663
@export
664664
def unique(ar: ArrayLike, return_index: bool = False, return_inverse: bool = False,
665665
return_counts: bool = False, axis: int | None = None,
666-
*, equal_nan: bool = True, size: int | None = None, fill_value: ArrayLike | None = None):
666+
*, equal_nan: bool = True, size: int | None = None, fill_value: ArrayLike | None = None,
667+
sorted: bool = True):
667668
"""Return the unique values from an array.
668669
669670
JAX implementation of :func:`numpy.unique`.
@@ -686,6 +687,7 @@ def unique(ar: ArrayLike, return_index: bool = False, return_inverse: bool = Fal
686687
unique elements than ``size`` indicates, the return value will be padded with ``fill_value``.
687688
fill_value: when ``size`` is specified and there are fewer than the indicated number of
688689
elements, fill the remaining entries ``fill_value``. Defaults to the minimum unique value.
690+
sorted: unused by JAX.
689691
690692
Returns:
691693
An array or tuple of arrays, depending on the values of ``return_index``, ``return_inverse``,
@@ -830,6 +832,10 @@ def unique(ar: ArrayLike, return_index: bool = False, return_inverse: bool = Fal
830832
>>> print(counts)
831833
[2 1]
832834
"""
835+
# TODO: Investigate if it's possible that we could save some work in
836+
# _unique_sorted_mask when sorting is not requested, but that would require
837+
# refactoring the implementation a bit.
838+
del sorted # unused
833839
arr = ensure_arraylike("unique", ar)
834840
if size is None:
835841
arr = core.concrete_or_error(None, arr,

tests/lax_numpy_test.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2127,7 +2127,7 @@ def testUniqueValues(self, shape, dtype):
21272127
if jtu.numpy_version() < (2, 0, 0):
21282128
np_fun = np.unique
21292129
else:
2130-
np_fun = np.unique_values
2130+
np_fun = lambda *args: np.sort(np.unique_values(*args))
21312131
self._CheckAgainstNumpy(jnp.unique_values, np_fun, args_maker)
21322132

21332133
@jtu.sample_product(
@@ -6452,9 +6452,10 @@ def testWrappedSignaturesMatch(self):
64526452
'compress': ['size', 'fill_value'],
64536453
'einsum': ['subscripts', 'precision'],
64546454
'einsum_path': ['subscripts'],
6455+
'fill_diagonal': ['inplace'],
64556456
'load': ['args', 'kwargs'],
64566457
'take_along_axis': ['mode', 'fill_value'],
6457-
'fill_diagonal': ['inplace'],
6458+
'unique': ['size', 'fill_value'],
64586459
}
64596460

64606461
mismatches = {}

0 commit comments

Comments
 (0)