Skip to content

Commit 8800fe2

Browse files
committed
Better documentation for jnp.lexsort
1 parent 5d3cac6 commit 8800fe2

File tree

1 file changed

+61
-1
lines changed

1 file changed

+61
-1
lines changed

jax/_src/numpy/lax_numpy.py

Lines changed: 61 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10154,9 +10154,69 @@ def sort_complex(a: ArrayLike) -> Array:
1015410154
a = lax.sort(asarray(a))
1015510155
return lax.convert_element_type(a, dtypes.to_complex_dtype(a.dtype))
1015610156

10157-
@util.implements(np.lexsort)
10157+
1015810158
@partial(jit, static_argnames=('axis',))
1015910159
def lexsort(keys: Array | np.ndarray | Sequence[ArrayLike], axis: int = -1) -> Array:
10160+
"""Sort a sequence of keys in lexicographic order.
10161+
10162+
JAX implementation of :func:`numpy.lexsort`.
10163+
10164+
Args:
10165+
keys: a sequence of arrays to sort; all arrays must have the same shape.
10166+
The last key in the sequence is used as the primary key.
10167+
axis: the axis along which to sort (default: -1).
10168+
10169+
Returns:
10170+
An array of integers of shape ``keys[0].shape`` giving the indices of the
10171+
entries in lexicographically-sorted order.
10172+
10173+
See also:
10174+
- :func:`jax.numpy.argsort`: sort a single entry by index.
10175+
- :func:`jax.lax.sort`: direct XLA sorting API.
10176+
10177+
Examples:
10178+
:func:`lexsort` with a single key is equivalent to :func:`argsort`:
10179+
10180+
>>> key1 = jnp.array([4, 2, 3, 2, 5])
10181+
>>> jnp.lexsort([key1])
10182+
Array([1, 3, 2, 0, 4], dtype=int32)
10183+
>>> jnp.argsort(key1)
10184+
Array([1, 3, 2, 0, 4], dtype=int32)
10185+
10186+
With multiple keys, :func:`lexsort` uses the last key as the primary key:
10187+
10188+
>>> key2 = jnp.array([2, 1, 1, 2, 2])
10189+
>>> jnp.lexsort([key1, key2])
10190+
Array([1, 2, 3, 0, 4], dtype=int32)
10191+
10192+
The meaning of the indices become more clear when printing the sorted keys:
10193+
10194+
>>> indices = jnp.lexsort([key1, key2])
10195+
>>> print(f"{key1[indices]}\\n{key2[indices]}")
10196+
[2 3 2 4 5]
10197+
[1 1 2 2 2]
10198+
10199+
Notice that the elements of ``key2`` appear in order, and within the sequences
10200+
of duplicated values the corresponding elements of ```key1`` appear in order.
10201+
10202+
For multi-dimensional inputs, :func:`lexsort` defaults to sorting along the
10203+
last axis:
10204+
10205+
>>> key1 = jnp.array([[2, 4, 2, 3],
10206+
... [3, 1, 2, 2]])
10207+
>>> key2 = jnp.array([[1, 2, 1, 3],
10208+
... [2, 1, 2, 1]])
10209+
>>> jnp.lexsort([key1, key2])
10210+
Array([[0, 2, 1, 3],
10211+
[1, 3, 2, 0]], dtype=int32)
10212+
10213+
A different sort axis can be chosen using the ``axis`` keyword; here we sort
10214+
along the leading axis:
10215+
10216+
>>> jnp.lexsort([key1, key2], axis=0)
10217+
Array([[0, 1, 0, 1],
10218+
[1, 0, 1, 0]], dtype=int32)
10219+
"""
1016010220
key_tuple = tuple(keys)
1016110221
util.check_arraylike("lexsort", *key_tuple)
1016210222
key_arrays = tuple(asarray(k) for k in key_tuple)

0 commit comments

Comments
 (0)