@@ -671,10 +671,12 @@ def rewriting_take(arr, idx, indices_are_sorted=False, unique_indices=False,
671671# @api.jit(static_argnums=(1, 2))
672672def _gather (arr , dynamic_idx , * , treedef , static_idx , indices_are_sorted ,
673673 unique_indices , mode , fill_value , normalize_indices ):
674+ parsed_mode = slicing .GatherScatterMode .from_any (mode )
674675 idx = merge_static_and_dynamic_indices (treedef , static_idx , dynamic_idx )
675- indexer = index_to_gather (
676+ indexer = index_to_gather ( # shared with _scatter_update
676677 np .shape (arr ), idx , core .typeof (arr ).sharding ,
677- normalize_indices = normalize_indices ) # shared with _scatter_update
678+ normalize_indices = normalize_indices ,
679+ raise_on_oob = (parsed_mode == slicing .GatherScatterMode .BOUNDS_CHECK ))
678680 jnp_error ._check_precondition_oob_gather (arr .shape , indexer .gather_indices )
679681 y = arr
680682
@@ -790,7 +792,9 @@ def _aval_or_none(x):
790792 return None
791793
792794def index_to_gather (x_shape : Sequence [int ], idx : Sequence [Any ],
793- x_sharding , normalize_indices : bool = True ) -> _Indexer :
795+ x_sharding , * ,
796+ normalize_indices : bool = True ,
797+ raise_on_oob : bool = False ) -> _Indexer :
794798 # Convert sequences to arrays
795799 idx = tuple (lax_numpy .asarray (i , dtype = None if i else int )
796800 if isinstance (i , Sequence ) else i for i in idx )
@@ -835,6 +839,24 @@ def index_to_gather(x_shape: Sequence[int], idx: Sequence[Any],
835839 x_shape = tuple (x_shape )
836840 x_spec = tuple (x_spec )
837841
842+ # TODO(jakevdp): for efficiency, we should handle normalize_indices just once here.
843+ # Also, we need to normalize indices statically where possible.
844+
845+ if raise_on_oob :
846+ idx_no_nones = [ind for ind in idx if ind is not None ]
847+ assert len (idx_no_nones ) == len (x_shape )
848+ def _check_static_index_in_bounds (ind , axis_num ):
849+ if not isinstance (ind , (int , np .integer )):
850+ return
851+ user_ind = ind
852+ if normalize_indices :
853+ ind = ind + x_shape [axis_num ] if ind < 0 else ind
854+ if not (0 <= ind < x_shape [axis_num ]):
855+ raise IndexError (f"index { user_ind } is out of bounds for axis { axis_num } "
856+ f" with size { x_shape [axis_num ]} " )
857+ for axis_num , ind in enumerate (idx_no_nones ):
858+ _check_static_index_in_bounds (ind , axis_num )
859+
838860 # Check for advanced indexing:
839861 # https://docs.scipy.org/doc/numpy/reference/arrays.indexing.html#advanced-indexing
840862
0 commit comments