diff --git a/jax/_src/numpy/indexing.py b/jax/_src/numpy/indexing.py index f1f5ffbc1d7d..c35dd1782f01 100644 --- a/jax/_src/numpy/indexing.py +++ b/jax/_src/numpy/indexing.py @@ -77,12 +77,14 @@ def take( fill_value: The fill value to return for out-of-bounds slices when mode is 'fill'. Ignored otherwise. Defaults to NaN for inexact types, the largest negative value for signed types, the largest positive value for unsigned types, and True for booleans. - unique_indices: If True, the implementation will assume that the indices are unique, - which can result in more efficient execution on some backends. If set to True and - indices are not unique, the output is undefined. + unique_indices: If True, the implementation will assume that the indices are unique + after normalization of negative indices, which lets the compiler emit more efficient + code during the backward pass. If set to True and normalized indices are not unique, + the result is implementation-defined and may be non-deterministic. indices_are_sorted : If True, the implementation will assume that the indices are - sorted in ascending order, which can lead to more efficient execution on some - backends. If set to True and indices are not sorted, the output is undefined. + sorted in ascending order after normalization of negative indices, which can lead + to more efficient execution on some backends. If set to True and normalized indices + are not sorted, the output is implementation-defined. Returns: Array of values extracted from ``a``.