Skip to content

Commit 56eea2b

Browse files
Merge pull request jax-ml#24312 from jakevdp:gather-doc
PiperOrigin-RevId: 686372450
2 parents 66c6292 + 284ca8b commit 56eea2b

File tree

1 file changed

+76
-8
lines changed

1 file changed

+76
-8
lines changed

jax/_src/lax/slicing.py

Lines changed: 76 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -310,10 +310,10 @@ def gather(operand: ArrayLike, start_indices: ArrayLike,
310310
Wraps `XLA's Gather operator
311311
<https://www.tensorflow.org/xla/operation_semantics#gather>`_.
312312
313-
The semantics of gather are complicated, and its API might change in the
314-
future. For most use cases, you should prefer `Numpy-style indexing
315-
<https://numpy.org/doc/stable/reference/arrays.indexing.html>`_
316-
(e.g., `x[:, (1,4,7), ...]`), rather than using `gather` directly.
313+
:func:`gather` is a low-level operator with complicated semantics, and most JAX
314+
users will never need to call it directly. Instead, you should prefer using
315+
`Numpy-style indexing`_, and/or :func:`jax.numpy.ndarray.at`, perhaps in combination
316+
with :func:`jax.vmap`.
317317
318318
Args:
319319
operand: an array from which slices should be taken
@@ -340,6 +340,42 @@ def gather(operand: ArrayLike, start_indices: ArrayLike,
340340
341341
Returns:
342342
An array containing the gather output.
343+
344+
Examples:
345+
As mentioned above, you should basically never use :func:`gather` directly,
346+
and instead use NumPy-style indexing expressions to gather values from
347+
arrays.
348+
349+
For example, here is how you can extract values at particular indices using
350+
straightforward indexing semantics, which will lower to XLA's Gather operator:
351+
352+
>>> import jax.numpy as jnp
353+
>>> x = jnp.array([10, 11, 12])
354+
>>> indices = jnp.array([0, 1, 1, 2, 2, 2])
355+
356+
>>> x[indices]
357+
Array([10, 11, 11, 12, 12, 12], dtype=int32)
358+
359+
For control over settings like ``indices_are_sorted``, ``unique_indices``, ``mode``,
360+
and ``fill_value``, you can use the :attr:`jax.numpy.ndarray.at` syntax:
361+
362+
>>> x.at[indices].get(indices_are_sorted=True, mode="promise_in_bounds")
363+
Array([10, 11, 11, 12, 12, 12], dtype=int32)
364+
365+
By comparison, here is the equivalent function call using :func:`gather` directly,
366+
which is not something typical users should ever need to do:
367+
368+
>>> from jax import lax
369+
>>> lax.gather(x, indices[:, None], slice_sizes=(1,),
370+
... dimension_numbers=lax.GatherDimensionNumbers(
371+
... offset_dims=(),
372+
... collapsed_slice_dims=(0,),
373+
... start_index_map=(0,)),
374+
... indices_are_sorted=True,
375+
... mode=lax.GatherScatterMode.PROMISE_IN_BOUNDS)
376+
Array([10, 11, 11, 12, 12, 12], dtype=int32)
377+
378+
.. _Numpy-style indexing: https://numpy.org/doc/stable/reference/arrays.indexing.html
343379
"""
344380
if mode is None:
345381
mode = GatherScatterMode.PROMISE_IN_BOUNDS
@@ -737,10 +773,9 @@ def scatter(
737773
If multiple updates are performed to the same index of operand, they may be
738774
applied in any order.
739775
740-
The semantics of scatter are complicated, and its API might change in the
741-
future. For most use cases, you should prefer the
742-
:attr:`jax.numpy.ndarray.at` property on JAX arrays which uses
743-
the familiar NumPy indexing syntax.
776+
:func:`scatter` is a low-level operator with complicated semantics, and most
777+
JAX users will never need to call it directly. Instead, you should prefer using
778+
:func:`jax.numpy.ndarray.at` for more familiary NumPy-style indexing syntax.
744779
745780
Args:
746781
operand: an array to which the scatter should be applied
@@ -764,6 +799,39 @@ def scatter(
764799
765800
Returns:
766801
An array containing the sum of `operand` and the scattered updates.
802+
803+
Examples:
804+
As mentioned above, you should basically never use :func:`scatter` directly,
805+
and instead perform scatter-style operations using NumPy-style indexing
806+
expressions via :attr:`jax.numpy.ndarray.at`.
807+
808+
Here is and example of updating entries in an array using :attr:`jax.numpy.ndarray.at`,
809+
which lowers to an XLA Scatter operation:
810+
811+
>>> x = jnp.zeros(5)
812+
>>> indices = jnp.array([1, 2, 4])
813+
>>> values = jnp.array([2.0, 3.0, 4.0])
814+
815+
>>> x.at[indices].set(values)
816+
Array([0., 2., 3., 0., 4.], dtype=float32)
817+
818+
This syntax also supports several of the optional arguments to :func:`scatter`,
819+
for example:
820+
821+
>>> x.at[indices].set(values, indices_are_sorted=True, mode='promise_in_bounds')
822+
Array([0., 2., 3., 0., 4.], dtype=float32)
823+
824+
By comparison, here is the equivalent function call using :func:`scatter` directly,
825+
which is not something typical users should ever need to do:
826+
827+
>>> lax.scatter(x, indices[:, None], values,
828+
... dimension_numbers=lax.ScatterDimensionNumbers(
829+
... update_window_dims=(),
830+
... inserted_window_dims=(0,),
831+
... scatter_dims_to_operand_dims=(0,)),
832+
... indices_are_sorted=True,
833+
... mode=lax.GatherScatterMode.PROMISE_IN_BOUNDS)
834+
Array([0., 2., 3., 0., 4.], dtype=float32)
767835
"""
768836
return scatter_p.bind(
769837
operand, scatter_indices, updates, update_jaxpr=None,

0 commit comments

Comments
 (0)