@@ -10154,9 +10154,69 @@ def sort_complex(a: ArrayLike) -> Array:
1015410154 a = lax .sort (asarray (a ))
1015510155 return lax .convert_element_type (a , dtypes .to_complex_dtype (a .dtype ))
1015610156
10157- @ util . implements ( np . lexsort )
10157+
1015810158@partial (jit , static_argnames = ('axis' ,))
1015910159def lexsort (keys : Array | np .ndarray | Sequence [ArrayLike ], axis : int = - 1 ) -> Array :
10160+ """Sort a sequence of keys in lexicographic order.
10161+
10162+ JAX implementation of :func:`numpy.lexsort`.
10163+
10164+ Args:
10165+ keys: a sequence of arrays to sort; all arrays must have the same shape.
10166+ The last key in the sequence is used as the primary key.
10167+ axis: the axis along which to sort (default: -1).
10168+
10169+ Returns:
10170+ An array of integers of shape ``keys[0].shape`` giving the indices of the
10171+ entries in lexicographically-sorted order.
10172+
10173+ See also:
10174+ - :func:`jax.numpy.argsort`: sort a single entry by index.
10175+ - :func:`jax.lax.sort`: direct XLA sorting API.
10176+
10177+ Examples:
10178+ :func:`lexsort` with a single key is equivalent to :func:`argsort`:
10179+
10180+ >>> key1 = jnp.array([4, 2, 3, 2, 5])
10181+ >>> jnp.lexsort([key1])
10182+ Array([1, 3, 2, 0, 4], dtype=int32)
10183+ >>> jnp.argsort(key1)
10184+ Array([1, 3, 2, 0, 4], dtype=int32)
10185+
10186+ With multiple keys, :func:`lexsort` uses the last key as the primary key:
10187+
10188+ >>> key2 = jnp.array([2, 1, 1, 2, 2])
10189+ >>> jnp.lexsort([key1, key2])
10190+ Array([1, 2, 3, 0, 4], dtype=int32)
10191+
10192+ The meaning of the indices become more clear when printing the sorted keys:
10193+
10194+ >>> indices = jnp.lexsort([key1, key2])
10195+ >>> print(f"{key1[indices]}\\ n{key2[indices]}")
10196+ [2 3 2 4 5]
10197+ [1 1 2 2 2]
10198+
10199+ Notice that the elements of ``key2`` appear in order, and within the sequences
10200+ of duplicated values the corresponding elements of ```key1`` appear in order.
10201+
10202+ For multi-dimensional inputs, :func:`lexsort` defaults to sorting along the
10203+ last axis:
10204+
10205+ >>> key1 = jnp.array([[2, 4, 2, 3],
10206+ ... [3, 1, 2, 2]])
10207+ >>> key2 = jnp.array([[1, 2, 1, 3],
10208+ ... [2, 1, 2, 1]])
10209+ >>> jnp.lexsort([key1, key2])
10210+ Array([[0, 2, 1, 3],
10211+ [1, 3, 2, 0]], dtype=int32)
10212+
10213+ A different sort axis can be chosen using the ``axis`` keyword; here we sort
10214+ along the leading axis:
10215+
10216+ >>> jnp.lexsort([key1, key2], axis=0)
10217+ Array([[0, 1, 0, 1],
10218+ [1, 0, 1, 0]], dtype=int32)
10219+ """
1016010220 key_tuple = tuple (keys )
1016110221 util .check_arraylike ("lexsort" , * key_tuple )
1016210222 key_arrays = tuple (asarray (k ) for k in key_tuple )
0 commit comments