Skip to content

Commit 2cee99e

Browse files
committed
Add new BOUNDS_CHECK GatherScatterMode
1 parent a83c167 commit 2cee99e

File tree

7 files changed

+63
-8
lines changed

7 files changed

+63
-8
lines changed

jax/_src/lax/slicing.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -306,11 +306,16 @@ class GatherScatterMode(enum.Enum):
306306
performed. In practice, with the current XLA implementation this means
307307
that out-of-bounds gathers will be clamped but out-of-bounds scatters will
308308
be discarded. Gradients will not be correct if indices are out-of-bounds.
309+
BOUNDS_CHECK:
310+
When possible, an error will be raised for out-of-bounds indices (for example, when
311+
indices are :term:`static` or when used with :mod:`jax.experimental.checkify`).
312+
Otherwise, behavior is identical to FILL_OR_DROP.
309313
"""
310314
CLIP = enum.auto()
311315
FILL_OR_DROP = enum.auto()
312316
PROMISE_IN_BOUNDS = enum.auto()
313317
ONE_HOT = enum.auto()
318+
BOUNDS_CHECK = enum.auto()
314319

315320
@staticmethod
316321
def from_any(s: str | GatherScatterMode | None) -> GatherScatterMode:
@@ -324,6 +329,8 @@ def from_any(s: str | GatherScatterMode | None) -> GatherScatterMode:
324329
return GatherScatterMode.PROMISE_IN_BOUNDS
325330
if s == "one_hot":
326331
return GatherScatterMode.ONE_HOT
332+
if s == "bounds_check":
333+
return GatherScatterMode.BOUNDS_CHECK
327334
else:
328335
raise ValueError(f'Unknown gather mode "{s}"')
329336

@@ -411,7 +418,7 @@ def gather(operand: ArrayLike, start_indices: ArrayLike,
411418
if mode is None:
412419
mode = GatherScatterMode.PROMISE_IN_BOUNDS
413420
parsed_mode = GatherScatterMode.from_any(mode)
414-
if parsed_mode == GatherScatterMode.FILL_OR_DROP:
421+
if parsed_mode in (GatherScatterMode.BOUNDS_CHECK, GatherScatterMode.FILL_OR_DROP):
415422
if fill_value is None:
416423
dtype = _dtype(operand)
417424
if dtypes.issubdtype(dtype, np.inexact):
@@ -2382,7 +2389,7 @@ def _gather_lower(ctx, operand, indices, *,
23822389
indices_are_sorted=indices_are_sorted, mode=mode,
23832390
fill_value=fill_value)]
23842391

2385-
if mode == GatherScatterMode.FILL_OR_DROP:
2392+
if mode in [GatherScatterMode.BOUNDS_CHECK, GatherScatterMode.FILL_OR_DROP]:
23862393
gather_fill_fn = mlir.lower_fun(_gather_fill, multiple_results=False)
23872394
return gather_fill_fn(
23882395
ctx, operand, indices,

jax/_src/numpy/indexing.py

Lines changed: 25 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -671,10 +671,12 @@ def rewriting_take(arr, idx, indices_are_sorted=False, unique_indices=False,
671671
# @api.jit(static_argnums=(1, 2))
672672
def _gather(arr, dynamic_idx, *, treedef, static_idx, indices_are_sorted,
673673
unique_indices, mode, fill_value, normalize_indices):
674+
parsed_mode = slicing.GatherScatterMode.from_any(mode)
674675
idx = merge_static_and_dynamic_indices(treedef, static_idx, dynamic_idx)
675-
indexer = index_to_gather(
676+
indexer = index_to_gather( # shared with _scatter_update
676677
np.shape(arr), idx, core.typeof(arr).sharding,
677-
normalize_indices=normalize_indices) # shared with _scatter_update
678+
normalize_indices=normalize_indices,
679+
raise_on_oob=(parsed_mode == slicing.GatherScatterMode.BOUNDS_CHECK))
678680
jnp_error._check_precondition_oob_gather(arr.shape, indexer.gather_indices)
679681
y = arr
680682

@@ -790,7 +792,9 @@ def _aval_or_none(x):
790792
return None
791793

792794
def index_to_gather(x_shape: Sequence[int], idx: Sequence[Any],
793-
x_sharding, normalize_indices: bool = True) -> _Indexer:
795+
x_sharding, *,
796+
normalize_indices: bool = True,
797+
raise_on_oob: bool = False) -> _Indexer:
794798
# Convert sequences to arrays
795799
idx = tuple(lax_numpy.asarray(i, dtype=None if i else int)
796800
if isinstance(i, Sequence) else i for i in idx)
@@ -835,6 +839,24 @@ def index_to_gather(x_shape: Sequence[int], idx: Sequence[Any],
835839
x_shape = tuple(x_shape)
836840
x_spec = tuple(x_spec)
837841

842+
# TODO(jakevdp): for efficiency, we should handle normalize_indices just once here.
843+
# Also, we need to normalize indices statically where possible.
844+
845+
if raise_on_oob:
846+
idx_no_nones = [ind for ind in idx if ind is not None]
847+
assert len(idx_no_nones) == len(x_shape)
848+
def _check_static_index_in_bounds(ind, axis_num):
849+
if not isinstance(ind, (int, np.integer)):
850+
return
851+
user_ind = ind
852+
if normalize_indices:
853+
ind = ind + x_shape[axis_num] if ind < 0 else ind
854+
if not (0 <= ind < x_shape[axis_num]):
855+
raise IndexError(f"index {user_ind} is out of bounds for axis {axis_num}"
856+
f" with size {x_shape[axis_num]}")
857+
for axis_num, ind in enumerate(idx_no_nones):
858+
_check_static_index_in_bounds(ind, axis_num)
859+
838860
# Check for advanced indexing:
839861
# https://docs.scipy.org/doc/numpy/reference/arrays.indexing.html#advanced-indexing
840862

jax/_src/ops/scatter.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,7 @@ def _scatter_impl(x: ArrayLike, y: ArrayLike, dynamic_idx: tuple[Any, ...], *,
102102
mode: slicing.GatherScatterMode | str | None, normalize_indices: bool):
103103
dtype = lax.dtype(x)
104104
weak_type = dtypes.is_weakly_typed(x)
105+
parsed_mode = slicing.GatherScatterMode.from_any(mode)
105106

106107
if not dtypes.safe_to_cast(y, x):
107108
# TODO(jakevdp): change this to an error after the deprecation period.
@@ -113,8 +114,10 @@ def _scatter_impl(x: ArrayLike, y: ArrayLike, dynamic_idx: tuple[Any, ...], *,
113114
FutureWarning)
114115

115116
idx = indexing.merge_static_and_dynamic_indices(treedef, static_idx, dynamic_idx)
116-
indexer = indexing.index_to_gather(np.shape(x), idx, core.typeof(x).sharding,
117-
normalize_indices=normalize_indices)
117+
indexer = indexing.index_to_gather(
118+
np.shape(x), idx, core.typeof(x).sharding,
119+
normalize_indices=normalize_indices,
120+
raise_on_oob=(parsed_mode == slicing.GatherScatterMode.BOUNDS_CHECK))
118121

119122
# Avoid calling scatter if the slice shape is empty, both as a fast path and
120123
# to handle cases like zeros(0)[array([], int32)].

jax/_src/pallas/mosaic/lowering.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2533,6 +2533,7 @@ def _gather_lowering_rule(
25332533
and mode
25342534
in (
25352535
lax.GatherScatterMode.FILL_OR_DROP,
2536+
lax.GatherScatterMode.BOUNDS_CHECK,
25362537
lax.GatherScatterMode.PROMISE_IN_BOUNDS,
25372538
)
25382539
):

jax/experimental/roofline/rooflines.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -517,7 +517,7 @@ def _calculate_gather_flops(
517517
) -> int:
518518
"""Calculates roofline unfused flops for Jax's gather primitive."""
519519

520-
if mode == slicing.GatherScatterMode.FILL_OR_DROP:
520+
if mode in [slicing.GatherScatterMode.BOUNDS_CHECK, slicing.GatherScatterMode.FILL_OR_DROP]:
521521
# With FILL_OR_DROP (or BOUNDS_CHECK, which takes this same branch), we have 4 steps to check whether to fill (or drop):
522522
# 1. Check if the index is within upper bound.
523523
# 2. Check if the index is within lower bound.

tests/lax_numpy_indexing_test.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1204,6 +1204,23 @@ def testIndexOutOfBounds(self): # https://github.com/jax-ml/jax/issues/2245
12041204
x.at[idx].get(mode="fill", fill_value=7),
12051205
jnp.array([7, 7, 1, 2, 1, 4, 5, 7, 7, 7], jnp.int32))
12061206

1207+
@parameterized.parameters(
1208+
((2, 3), 4, "index 4 is out of bounds for axis 0 with size 2"),
1209+
((2, 3), (0, 4), "index 4 is out of bounds for axis 1 with size 3"),
1210+
((2, 3), (..., 4), "index 4 is out of bounds for axis 1 with size 3"),
1211+
((2, 3, 5), (..., -10), "index -10 is out of bounds for axis 2 with size 5"),
1212+
((3,), (4, None), "index 4 is out of bounds for axis 0 with size 3"),
1213+
)
1214+
def testBoundsCheck(self, shape, idx, msg):
1215+
x = jnp.zeros(shape)
1216+
with self.subTest("gather"):
1217+
with self.assertRaisesRegex(IndexError, msg):
1218+
x.at[idx].get(mode="bounds_check")
1219+
with self.subTest("scatter"):
1220+
with self.assertRaisesRegex(IndexError, msg):
1221+
x.at[idx].set(0.0, mode="bounds_check")
1222+
1223+
12071224
def testIndexingWeakTypes(self):
12081225
x = lax_internal._convert_element_type(jnp.arange(5), dtypes.dtype(float),
12091226
weak_type=True)

tests/roofline_test.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -919,6 +919,11 @@ def with_neg(f):
919919
mode=lax.GatherScatterMode.FILL_OR_DROP,
920920
expected_flops=4 * 2 * 1 + 2 * 3,
921921
),
922+
dict(
923+
testcase_name="bounds_check",
924+
mode=lax.GatherScatterMode.BOUNDS_CHECK,
925+
expected_flops=4 * 2 * 1 + 2 * 3,
926+
),
922927
)
923928
def test_gather_roofline(self, mode, expected_flops):
924929
operand = jnp.zeros((3, 3), dtype=jnp.int32)

0 commit comments

Comments
 (0)