@@ -671,10 +671,12 @@ def rewriting_take(arr, idx, indices_are_sorted=False, unique_indices=False,
671671# @api.jit(static_argnums=(1, 2))
672672def _gather(arr, dynamic_idx, *, treedef, static_idx, indices_are_sorted,
673673 unique_indices, mode, fill_value, normalize_indices):
674+ parsed_mode = slicing.GatherScatterMode.from_any(mode)
674675 idx = merge_static_and_dynamic_indices(treedef, static_idx, dynamic_idx)
675- indexer = index_to_gather(
676+ indexer = index_to_gather( # shared with _scatter_update
676677 np.shape(arr), idx, core.typeof(arr).sharding,
677- normalize_indices=normalize_indices) # shared with _scatter_update
678+ normalize_indices=normalize_indices,
679+ raise_on_oob=(parsed_mode == slicing.GatherScatterMode.BOUNDS_CHECK))
678680 jnp_error._check_precondition_oob_gather(arr.shape, indexer.gather_indices)
679681 y = arr
680682
@@ -790,7 +792,9 @@ def _aval_or_none(x):
790792 return None
791793
792794def index_to_gather(x_shape: Sequence[int], idx: Sequence[Any],
793- x_sharding, normalize_indices: bool = True) -> _Indexer:
795+ x_sharding, *,
796+ normalize_indices: bool = True,
797+ raise_on_oob: bool = False) -> _Indexer:
794798 # Convert sequences to arrays
795799 idx = tuple(lax_numpy.asarray(i, dtype=None if i else int)
796800 if isinstance(i, Sequence) else i for i in idx)
@@ -835,6 +839,24 @@ def index_to_gather(x_shape: Sequence[int], idx: Sequence[Any],
835839 x_shape = tuple(x_shape)
836840 x_spec = tuple(x_spec)
837841
842+ # TODO(jakevdp): for efficiency, we should handle normalize_indices just once here.
843+ # Also, we need to normalize indices statically where possible.
844+
845+ if raise_on_oob:
846+ idx_no_nones = [ind for ind in idx if ind is not None]
847+ assert len(idx_no_nones) == len(x_shape)
848+ def _check_static_index_in_bounds(ind, axis_num):
849+ if not isinstance(ind, (int, np.integer)):
850+ return
851+ user_ind = ind
852+ if normalize_indices:
853+ ind = ind + x_shape[axis_num] if ind < 0 else ind
854+ if not (0 <= ind < x_shape[axis_num]):
855+ raise IndexError(f"index {user_ind} is out of bounds for axis {axis_num}"
856+ f" with size {x_shape[axis_num]}")
857+ for axis_num, ind in enumerate(idx_no_nones):
858+ _check_static_index_in_bounds(ind, axis_num)
859+
838860 # Check for advanced indexing:
839861 # https://docs.scipy.org/doc/numpy/reference/arrays.indexing.html#advanced-indexing
840862
0 commit comments