@@ -310,10 +310,10 @@ def gather(operand: ArrayLike, start_indices: ArrayLike,
310310 Wraps `XLA's Gather operator
311311 <https://www.tensorflow.org/xla/operation_semantics#gather>`_.
312312
313- The semantics of gather are complicated, and its API might change in the
314- future. For most use cases , you should prefer `Numpy-style indexing
315- <https:// numpy.org/doc/stable/reference/arrays.indexing.html>`_
316- (e.g., `x[:, (1,4,7), ...]`), rather than using `gather` directly .
313+ :func:`gather` is a low-level operator with complicated semantics , and most JAX
314+ users will never need to call it directly. Instead , you should prefer using
315+ `Numpy-style indexing`_, and/or :func:`jax. numpy.ndarray.at`, perhaps in combination
316+ with :func:`jax.vmap` .
317317
318318 Args:
319319 operand: an array from which slices should be taken
@@ -340,6 +340,42 @@ def gather(operand: ArrayLike, start_indices: ArrayLike,
340340
341341 Returns:
342342 An array containing the gather output.
343+
344+ Examples:
345+ As mentioned above, you should basically never use :func:`gather` directly,
346+ and instead use NumPy-style indexing expressions to gather values from
347+ arrays.
348+
349+ For example, here is how you can extract values at particular indices using
350+ straightforward indexing semantics, which will lower to XLA's Gather operator:
351+
352+ >>> import jax.numpy as jnp
353+ >>> x = jnp.array([10, 11, 12])
354+ >>> indices = jnp.array([0, 1, 1, 2, 2, 2])
355+
356+ >>> x[indices]
357+ Array([10, 11, 11, 12, 12, 12], dtype=int32)
358+
359+ For control over settings like ``indices_are_sorted``, ``unique_indices``, ``mode``,
360+ and ``fill_value``, you can use the :attr:`jax.numpy.ndarray.at` syntax:
361+
362+ >>> x.at[indices].get(indices_are_sorted=True, mode="promise_in_bounds")
363+ Array([10, 11, 11, 12, 12, 12], dtype=int32)
364+
365+ By comparison, here is the equivalent function call using :func:`gather` directly,
366+ which is not something typical users should ever need to do:
367+
368+ >>> from jax import lax
369+ >>> lax.gather(x, indices[:, None], slice_sizes=(1,),
370+ ... dimension_numbers=lax.GatherDimensionNumbers(
371+ ... offset_dims=(),
372+ ... collapsed_slice_dims=(0,),
373+ ... start_index_map=(0,)),
374+ ... indices_are_sorted=True,
375+ ... mode=lax.GatherScatterMode.PROMISE_IN_BOUNDS)
376+ Array([10, 11, 11, 12, 12, 12], dtype=int32)
377+
378+ .. _Numpy-style indexing: https://numpy.org/doc/stable/reference/arrays.indexing.html
343379 """
344380 if mode is None :
345381 mode = GatherScatterMode .PROMISE_IN_BOUNDS
@@ -737,10 +773,9 @@ def scatter(
737773 If multiple updates are performed to the same index of operand, they may be
738774 applied in any order.
739775
740- The semantics of scatter are complicated, and its API might change in the
741- future. For most use cases, you should prefer the
742- :attr:`jax.numpy.ndarray.at` property on JAX arrays which uses
743- the familiar NumPy indexing syntax.
776+ :func:`scatter` is a low-level operator with complicated semantics, and most
777+ JAX users will never need to call it directly. Instead, you should prefer using
778+ :func:`jax.numpy.ndarray.at` for more familiary NumPy-style indexing syntax.
744779
745780 Args:
746781 operand: an array to which the scatter should be applied
@@ -764,6 +799,39 @@ def scatter(
764799
765800 Returns:
766801 An array containing the sum of `operand` and the scattered updates.
802+
803+ Examples:
804+ As mentioned above, you should basically never use :func:`scatter` directly,
805+ and instead perform scatter-style operations using NumPy-style indexing
806+ expressions via :attr:`jax.numpy.ndarray.at`.
807+
808+ Here is and example of updating entries in an array using :attr:`jax.numpy.ndarray.at`,
809+ which lowers to an XLA Scatter operation:
810+
811+ >>> x = jnp.zeros(5)
812+ >>> indices = jnp.array([1, 2, 4])
813+ >>> values = jnp.array([2.0, 3.0, 4.0])
814+
815+ >>> x.at[indices].set(values)
816+ Array([0., 2., 3., 0., 4.], dtype=float32)
817+
818+ This syntax also supports several of the optional arguments to :func:`scatter`,
819+ for example:
820+
821+ >>> x.at[indices].set(values, indices_are_sorted=True, mode='promise_in_bounds')
822+ Array([0., 2., 3., 0., 4.], dtype=float32)
823+
824+ By comparison, here is the equivalent function call using :func:`scatter` directly,
825+ which is not something typical users should ever need to do:
826+
827+ >>> lax.scatter(x, indices[:, None], values,
828+ ... dimension_numbers=lax.ScatterDimensionNumbers(
829+ ... update_window_dims=(),
830+ ... inserted_window_dims=(0,),
831+ ... scatter_dims_to_operand_dims=(0,)),
832+ ... indices_are_sorted=True,
833+ ... mode=lax.GatherScatterMode.PROMISE_IN_BOUNDS)
834+ Array([0., 2., 3., 0., 4.], dtype=float32)
767835 """
768836 return scatter_p .bind (
769837 operand , scatter_indices , updates , update_jaxpr = None ,
0 commit comments