diff --git a/jax/_src/lax/slicing.py b/jax/_src/lax/slicing.py index 85c328d88f1c..9d5023e8b4dc 100644 --- a/jax/_src/lax/slicing.py +++ b/jax/_src/lax/slicing.py @@ -306,11 +306,16 @@ class GatherScatterMode(enum.Enum): performed. In practice, with the current XLA implementation this means that out-of-bounds gathers will be clamped but out-of-bounds scatters will be discarded. Gradients will not be correct if indices are out-of-bounds. + BOUNDS_CHECK: + When possible, an error will be raised for out-of-bound indices (for example, when + indices are :term:`static` or when used with :mod:`jax.experimental.checkify`). + Otherwise, behavior is identical to FILL_OR_DROP. """ CLIP = enum.auto() FILL_OR_DROP = enum.auto() PROMISE_IN_BOUNDS = enum.auto() ONE_HOT = enum.auto() + BOUNDS_CHECK = enum.auto() @staticmethod def from_any(s: str | GatherScatterMode | None) -> GatherScatterMode: @@ -324,6 +329,8 @@ def from_any(s: str | GatherScatterMode | None) -> GatherScatterMode: return GatherScatterMode.PROMISE_IN_BOUNDS if s == "one_hot": return GatherScatterMode.ONE_HOT + if s == "bounds_check": + return GatherScatterMode.BOUNDS_CHECK else: raise ValueError(f'Unknown gather mode "{s}"') @@ -411,7 +418,7 @@ def gather(operand: ArrayLike, start_indices: ArrayLike, if mode is None: mode = GatherScatterMode.PROMISE_IN_BOUNDS parsed_mode = GatherScatterMode.from_any(mode) - if parsed_mode == GatherScatterMode.FILL_OR_DROP: + if parsed_mode in [GatherScatterMode.BOUNDS_CHECK, GatherScatterMode.FILL_OR_DROP]: if fill_value is None: dtype = _dtype(operand) if dtypes.issubdtype(dtype, np.inexact): @@ -2382,7 +2389,7 @@ def _gather_lower(ctx, operand, indices, *, indices_are_sorted=indices_are_sorted, mode=mode, fill_value=fill_value)] - if mode == GatherScatterMode.FILL_OR_DROP: + if mode in [GatherScatterMode.BOUNDS_CHECK, GatherScatterMode.FILL_OR_DROP]: gather_fill_fn = mlir.lower_fun(_gather_fill, multiple_results=False) return gather_fill_fn( ctx, operand, indices, diff --git a/jax/_src/numpy/indexing.py b/jax/_src/numpy/indexing.py index f1f5ffbc1d7d..8885038b31f8 100644 --- a/jax/_src/numpy/indexing.py +++ b/jax/_src/numpy/indexing.py @@ -671,10 +671,12 @@ def rewriting_take(arr, idx, indices_are_sorted=False, unique_indices=False, # @api.jit(static_argnums=(1, 2)) def _gather(arr, dynamic_idx, *, treedef, static_idx, indices_are_sorted, unique_indices, mode, fill_value, normalize_indices): + parsed_mode = slicing.GatherScatterMode.from_any(mode) idx = merge_static_and_dynamic_indices(treedef, static_idx, dynamic_idx) - indexer = index_to_gather( + indexer = index_to_gather( # shared with _scatter_update np.shape(arr), idx, core.typeof(arr).sharding, - normalize_indices=normalize_indices) # shared with _scatter_update + normalize_indices=normalize_indices, + raise_on_oob=(parsed_mode == slicing.GatherScatterMode.BOUNDS_CHECK)) jnp_error._check_precondition_oob_gather(arr.shape, indexer.gather_indices) y = arr @@ -790,7 +792,9 @@ def _aval_or_none(x): return None def index_to_gather(x_shape: Sequence[int], idx: Sequence[Any], - x_sharding, normalize_indices: bool = True) -> _Indexer: + x_sharding, *, + normalize_indices: bool = True, + raise_on_oob: bool = False) -> _Indexer: # Convert sequences to arrays idx = tuple(lax_numpy.asarray(i, dtype=None if i else int) if isinstance(i, Sequence) else i for i in idx) @@ -835,6 +839,24 @@ def index_to_gather(x_shape: Sequence[int], idx: Sequence[Any], x_shape = tuple(x_shape) x_spec = tuple(x_spec) + # TODO(jakevdp): for efficiency, we should handle normalize_indices just once here. + # Also, we need to normalize indices statically where possible. + + if raise_on_oob: + idx_no_nones = [ind for ind in idx if ind is not None] + assert len(idx_no_nones) == len(x_shape) + def _check_static_index_in_bounds(ind, axis_num): + if not isinstance(ind, (int, np.integer)): + return + user_ind = ind + if normalize_indices: + ind = ind + x_shape[axis_num] if ind < 0 else ind + if not (0 <= ind < x_shape[axis_num]): + raise IndexError(f"index {user_ind} is out of bounds for axis {axis_num}" + f" with size {x_shape[axis_num]}") + for axis_num, ind in enumerate(idx_no_nones): + _check_static_index_in_bounds(ind, axis_num) + # Check for advanced indexing: # https://docs.scipy.org/doc/numpy/reference/arrays.indexing.html#advanced-indexing diff --git a/jax/_src/ops/scatter.py b/jax/_src/ops/scatter.py index 739d4019b71f..e6e3bba7a72a 100644 --- a/jax/_src/ops/scatter.py +++ b/jax/_src/ops/scatter.py @@ -102,6 +102,7 @@ def _scatter_impl(x: ArrayLike, y: ArrayLike, dynamic_idx: tuple[Any, ...], *, mode: slicing.GatherScatterMode | str | None, normalize_indices: bool): dtype = lax.dtype(x) weak_type = dtypes.is_weakly_typed(x) + parsed_mode = slicing.GatherScatterMode.from_any(mode) if not dtypes.safe_to_cast(y, x): # TODO(jakevdp): change this to an error after the deprecation period. @@ -113,8 +114,10 @@ def _scatter_impl(x: ArrayLike, y: ArrayLike, dynamic_idx: tuple[Any, ...], *, FutureWarning) idx = indexing.merge_static_and_dynamic_indices(treedef, static_idx, dynamic_idx) - indexer = indexing.index_to_gather(np.shape(x), idx, core.typeof(x).sharding, - normalize_indices=normalize_indices) + indexer = indexing.index_to_gather( + np.shape(x), idx, core.typeof(x).sharding, + normalize_indices=normalize_indices, + raise_on_oob=(parsed_mode == slicing.GatherScatterMode.BOUNDS_CHECK)) # Avoid calling scatter if the slice shape is empty, both as a fast path and # to handle cases like zeros(0)[array([], int32)]. diff --git a/jax/_src/pallas/mosaic/lowering.py b/jax/_src/pallas/mosaic/lowering.py index cbe5e85322c7..7e88a7254c6e 100644 --- a/jax/_src/pallas/mosaic/lowering.py +++ b/jax/_src/pallas/mosaic/lowering.py @@ -2533,6 +2533,7 @@ def _gather_lowering_rule( and mode in ( lax.GatherScatterMode.FILL_OR_DROP, + lax.GatherScatterMode.BOUNDS_CHECK, lax.GatherScatterMode.PROMISE_IN_BOUNDS, ) ): diff --git a/jax/experimental/roofline/rooflines.py b/jax/experimental/roofline/rooflines.py index 87572c42181b..f1eec7268e8f 100644 --- a/jax/experimental/roofline/rooflines.py +++ b/jax/experimental/roofline/rooflines.py @@ -517,7 +517,7 @@ def _calculate_gather_flops( ) -> int: """Calculates roofline unfused flops for Jax's gather primitive.""" - if mode == slicing.GatherScatterMode.FILL_OR_DROP: + if mode in [slicing.GatherScatterMode.BOUNDS_CHECK, slicing.GatherScatterMode.FILL_OR_DROP]: # With FILL_OR_DROP, we have 4 steps to check whether to fill (or drop): # 1. Check if the index is within upper bound. # 2. Check if the index is within lower bound. diff --git a/tests/lax_numpy_indexing_test.py b/tests/lax_numpy_indexing_test.py index d9be101ac787..bbc48fa8dfd6 100644 --- a/tests/lax_numpy_indexing_test.py +++ b/tests/lax_numpy_indexing_test.py @@ -1204,6 +1204,41 @@ def testIndexOutOfBounds(self): # https://github.com/jax-ml/jax/issues/2245 x.at[idx].get(mode="fill", fill_value=7), jnp.array([7, 7, 1, 2, 1, 4, 5, 7, 7, 7], jnp.int32)) + @parameterized.parameters( + ((2, 3), 4, "index 4 is out of bounds for axis 0 with size 2"), + ((2, 3), (0, 4), "index 4 is out of bounds for axis 1 with size 3"), + ((2, 3), (-1, 4), "index 4 is out of bounds for axis 1 with size 3"), + ((2, 3, 5), (..., -10), "index -10 is out of bounds for axis 2 with size 5"), + ((3,), (-4, None), "index -4 is out of bounds for axis 0 with size 3"), + ) + def testBoundsCheck(self, shape, idx, msg): + x = jnp.zeros(shape) + + # Note: in both cases here we avoid passing idx to the function + # in order for it to remain static. + def f_gather(x): + return x.at[idx].get(mode="bounds_check") + + def f_scatter(x): + return x.at[idx].set(0.0, mode="bounds_check") + + with self.subTest("gather"): + with self.assertRaisesRegex(IndexError, msg): + f_gather(x) + + with self.subTest("gather-jit"): + with self.assertRaisesRegex(IndexError, msg): + jax.jit(f_gather)(x) + + with self.subTest("scatter"): + with self.assertRaisesRegex(IndexError, msg): + f_scatter(x) + + with self.subTest("scatter-jit"): + with self.assertRaisesRegex(IndexError, msg): + jax.jit(f_scatter)(x) + + def testIndexingWeakTypes(self): x = lax_internal._convert_element_type(jnp.arange(5), dtypes.dtype(float), weak_type=True) diff --git a/tests/roofline_test.py b/tests/roofline_test.py index 5480406aa104..545243a88a83 100644 --- a/tests/roofline_test.py +++ b/tests/roofline_test.py @@ -919,6 +919,11 @@ def with_neg(f): mode=lax.GatherScatterMode.FILL_OR_DROP, expected_flops=4 * 2 * 1 + 2 * 3, ), + dict( + testcase_name="bounds_check", + mode=lax.GatherScatterMode.BOUNDS_CHECK, + expected_flops=4 * 2 * 1 + 2 * 3, + ), ) def test_gather_roofline(self, mode, expected_flops): operand = jnp.zeros((3, 3), dtype=jnp.int32)