@@ -663,7 +663,8 @@ def _unique(ar: Array, axis: int, return_index: bool = False, return_inverse: bo
663663@export
664664def unique (ar : ArrayLike , return_index : bool = False , return_inverse : bool = False ,
665665 return_counts : bool = False , axis : int | None = None ,
666- * , equal_nan : bool = True , size : int | None = None , fill_value : ArrayLike | None = None ):
666+ * , equal_nan : bool = True , size : int | None = None , fill_value : ArrayLike | None = None ,
667+ sorted : bool = True ):
667668 """Return the unique values from an array.
668669
669670 JAX implementation of :func:`numpy.unique`.
@@ -686,6 +687,7 @@ def unique(ar: ArrayLike, return_index: bool = False, return_inverse: bool = Fal
686687 unique elements than ``size`` indicates, the return value will be padded with ``fill_value``.
687688 fill_value: when ``size`` is specified and there are fewer than the indicated number of
688689 elements, fill the remaining entries ``fill_value``. Defaults to the minimum unique value.
690+ sorted: unused by JAX.
689691
690692 Returns:
691693 An array or tuple of arrays, depending on the values of ``return_index``, ``return_inverse``,
@@ -830,6 +832,10 @@ def unique(ar: ArrayLike, return_index: bool = False, return_inverse: bool = Fal
830832 >>> print(counts)
831833 [2 1]
832834 """
835+ # TODO: Investigate if it's possible that we could save some work in
836+ # _unique_sorted_mask when sorting is not requested, but that would require
837+ # refactoring the implementation a bit.
838+ del sorted # unused
833839 arr = ensure_arraylike ("unique" , ar )
834840 if size is None :
835841 arr = core .concrete_or_error (None , arr ,
0 commit comments