Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 9 additions & 2 deletions jax/_src/lax/slicing.py
Original file line number Diff line number Diff line change
Expand Up @@ -306,11 +306,16 @@ class GatherScatterMode(enum.Enum):
performed. In practice, with the current XLA implementation this means
that out-of-bounds gathers will be clamped but out-of-bounds scatters will
be discarded. Gradients will not be correct if indices are out-of-bounds.
BOUNDS_CHECK:
When possible, an error will be raised for out-of-bounds indices (for example, when
indices are :term:`static` or when used with :mod:`jax.experimental.checkify`).
Otherwise, behavior is identical to FILL_OR_DROP.
"""
CLIP = enum.auto()
FILL_OR_DROP = enum.auto()
PROMISE_IN_BOUNDS = enum.auto()
ONE_HOT = enum.auto()
BOUNDS_CHECK = enum.auto()

@staticmethod
def from_any(s: str | GatherScatterMode | None) -> GatherScatterMode:
Expand All @@ -324,6 +329,8 @@ def from_any(s: str | GatherScatterMode | None) -> GatherScatterMode:
return GatherScatterMode.PROMISE_IN_BOUNDS
if s == "one_hot":
return GatherScatterMode.ONE_HOT
if s == "bounds_check":
return GatherScatterMode.BOUNDS_CHECK
else:
raise ValueError(f'Unknown gather mode "{s}"')

Expand Down Expand Up @@ -411,7 +418,7 @@ def gather(operand: ArrayLike, start_indices: ArrayLike,
if mode is None:
mode = GatherScatterMode.PROMISE_IN_BOUNDS
parsed_mode = GatherScatterMode.from_any(mode)
if parsed_mode == GatherScatterMode.FILL_OR_DROP:
if parsed_mode in [GatherScatterMode.BOUNDS_CHECK, GatherScatterMode.FILL_OR_DROP]:
if fill_value is None:
dtype = _dtype(operand)
if dtypes.issubdtype(dtype, np.inexact):
Expand Down Expand Up @@ -2382,7 +2389,7 @@ def _gather_lower(ctx, operand, indices, *,
indices_are_sorted=indices_are_sorted, mode=mode,
fill_value=fill_value)]

if mode == GatherScatterMode.FILL_OR_DROP:
if mode in [GatherScatterMode.BOUNDS_CHECK, GatherScatterMode.FILL_OR_DROP]:
gather_fill_fn = mlir.lower_fun(_gather_fill, multiple_results=False)
return gather_fill_fn(
ctx, operand, indices,
Expand Down
28 changes: 25 additions & 3 deletions jax/_src/numpy/indexing.py
Original file line number Diff line number Diff line change
Expand Up @@ -671,10 +671,12 @@ def rewriting_take(arr, idx, indices_are_sorted=False, unique_indices=False,
# @api.jit(static_argnums=(1, 2))
def _gather(arr, dynamic_idx, *, treedef, static_idx, indices_are_sorted,
unique_indices, mode, fill_value, normalize_indices):
parsed_mode = slicing.GatherScatterMode.from_any(mode)
idx = merge_static_and_dynamic_indices(treedef, static_idx, dynamic_idx)
indexer = index_to_gather(
indexer = index_to_gather( # shared with _scatter_update
np.shape(arr), idx, core.typeof(arr).sharding,
normalize_indices=normalize_indices) # shared with _scatter_update
normalize_indices=normalize_indices,
raise_on_oob=(parsed_mode == slicing.GatherScatterMode.BOUNDS_CHECK))
jnp_error._check_precondition_oob_gather(arr.shape, indexer.gather_indices)
y = arr

Expand Down Expand Up @@ -790,7 +792,9 @@ def _aval_or_none(x):
return None

def index_to_gather(x_shape: Sequence[int], idx: Sequence[Any],
x_sharding, normalize_indices: bool = True) -> _Indexer:
x_sharding, *,
normalize_indices: bool = True,
raise_on_oob: bool = False) -> _Indexer:
# Convert sequences to arrays
idx = tuple(lax_numpy.asarray(i, dtype=None if i else int)
if isinstance(i, Sequence) else i for i in idx)
Expand Down Expand Up @@ -835,6 +839,24 @@ def index_to_gather(x_shape: Sequence[int], idx: Sequence[Any],
x_shape = tuple(x_shape)
x_spec = tuple(x_spec)

# TODO(jakevdp): for efficiency, we should handle normalize_indices just once here.
# Also, we need to normalize indices statically where possible.

if raise_on_oob:
idx_no_nones = [ind for ind in idx if ind is not None]
assert len(idx_no_nones) == len(x_shape)
def _check_static_index_in_bounds(ind, axis_num):
if not isinstance(ind, (int, np.integer)):
return
user_ind = ind
if normalize_indices:
ind = ind + x_shape[axis_num] if ind < 0 else ind
if not (0 <= ind < x_shape[axis_num]):
raise IndexError(f"index {user_ind} is out of bounds for axis {axis_num}"
f" with size {x_shape[axis_num]}")
for axis_num, ind in enumerate(idx_no_nones):
_check_static_index_in_bounds(ind, axis_num)

# Check for advanced indexing:
# https://docs.scipy.org/doc/numpy/reference/arrays.indexing.html#advanced-indexing

Expand Down
7 changes: 5 additions & 2 deletions jax/_src/ops/scatter.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,7 @@ def _scatter_impl(x: ArrayLike, y: ArrayLike, dynamic_idx: tuple[Any, ...], *,
mode: slicing.GatherScatterMode | str | None, normalize_indices: bool):
dtype = lax.dtype(x)
weak_type = dtypes.is_weakly_typed(x)
parsed_mode = slicing.GatherScatterMode.from_any(mode)

if not dtypes.safe_to_cast(y, x):
# TODO(jakevdp): change this to an error after the deprecation period.
Expand All @@ -113,8 +114,10 @@ def _scatter_impl(x: ArrayLike, y: ArrayLike, dynamic_idx: tuple[Any, ...], *,
FutureWarning)

idx = indexing.merge_static_and_dynamic_indices(treedef, static_idx, dynamic_idx)
indexer = indexing.index_to_gather(np.shape(x), idx, core.typeof(x).sharding,
normalize_indices=normalize_indices)
indexer = indexing.index_to_gather(
np.shape(x), idx, core.typeof(x).sharding,
normalize_indices=normalize_indices,
raise_on_oob=(parsed_mode == slicing.GatherScatterMode.BOUNDS_CHECK))

# Avoid calling scatter if the slice shape is empty, both as a fast path and
# to handle cases like zeros(0)[array([], int32)].
Expand Down
1 change: 1 addition & 0 deletions jax/_src/pallas/mosaic/lowering.py
Original file line number Diff line number Diff line change
Expand Up @@ -2533,6 +2533,7 @@ def _gather_lowering_rule(
and mode
in (
lax.GatherScatterMode.FILL_OR_DROP,
lax.GatherScatterMode.BOUNDS_CHECK,
lax.GatherScatterMode.PROMISE_IN_BOUNDS,
)
):
Expand Down
2 changes: 1 addition & 1 deletion jax/experimental/roofline/rooflines.py
Original file line number Diff line number Diff line change
Expand Up @@ -517,7 +517,7 @@ def _calculate_gather_flops(
) -> int:
"""Calculates roofline unfused flops for Jax's gather primitive."""

if mode == slicing.GatherScatterMode.FILL_OR_DROP:
if mode in [slicing.GatherScatterMode.BOUNDS_CHECK, slicing.GatherScatterMode.FILL_OR_DROP]:
# With FILL_OR_DROP (or BOUNDS_CHECK, which takes this same branch), we have 4 steps to check whether to fill (or drop):
# 1. Check if the index is within upper bound.
# 2. Check if the index is within lower bound.
Expand Down
35 changes: 35 additions & 0 deletions tests/lax_numpy_indexing_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1204,6 +1204,41 @@ def testIndexOutOfBounds(self): # https://github.com/jax-ml/jax/issues/2245
x.at[idx].get(mode="fill", fill_value=7),
jnp.array([7, 7, 1, 2, 1, 4, 5, 7, 7, 7], jnp.int32))

@parameterized.parameters(
  ((2, 3), 4, "index 4 is out of bounds for axis 0 with size 2"),
  ((2, 3), (0, 4), "index 4 is out of bounds for axis 1 with size 3"),
  ((2, 3), (-1, 4), "index 4 is out of bounds for axis 1 with size 3"),
  ((2, 3, 5), (..., -10), "index -10 is out of bounds for axis 2 with size 5"),
  ((3,), (-4, None), "index -4 is out of bounds for axis 0 with size 3"),
)
def testBoundsCheck(self, shape, idx, msg):
  # Each parameter set pairs an array shape with a static index that is out
  # of bounds, plus the IndexError message expected for it.
  operand = jnp.zeros(shape)

  # idx is captured by closure instead of being passed as an argument, so it
  # remains static in both the eager and the jitted calls below.
  def f_gather(x):
    return x.at[idx].get(mode="bounds_check")

  def f_scatter(x):
    return x.at[idx].set(0.0, mode="bounds_check")

  # (subtest label, function under test, whether to wrap it in jax.jit)
  cases = (
    ("gather", f_gather, False),
    ("gather-jit", f_gather, True),
    ("scatter", f_scatter, False),
    ("scatter-jit", f_scatter, True),
  )
  for label, fn, jitted in cases:
    with self.subTest(label):
      with self.assertRaisesRegex(IndexError, msg):
        (jax.jit(fn) if jitted else fn)(operand)


def testIndexingWeakTypes(self):
x = lax_internal._convert_element_type(jnp.arange(5), dtypes.dtype(float),
weak_type=True)
Expand Down
5 changes: 5 additions & 0 deletions tests/roofline_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -919,6 +919,11 @@ def with_neg(f):
mode=lax.GatherScatterMode.FILL_OR_DROP,
expected_flops=4 * 2 * 1 + 2 * 3,
),
dict(
testcase_name="bounds_check",
mode=lax.GatherScatterMode.BOUNDS_CHECK,
expected_flops=4 * 2 * 1 + 2 * 3,
),
)
def test_gather_roofline(self, mode, expected_flops):
operand = jnp.zeros((3, 3), dtype=jnp.int32)
Expand Down
Loading