Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions jax/_src/checkify.py
Original file line number Diff line number Diff line change
Expand Up @@ -627,9 +627,9 @@ def dynamic_slice_error_check(error, enabled_errors, operand, *start_indices, sl
if OOBError not in enabled_errors:
return error, out

operand_dims = np.array(operand.shape)
slice_sizes = np.array(slice_sizes)
start_indices = jnp.array(start_indices)
operand_dims = np.array(operand.shape, dtype=start_indices.dtype)
slice_sizes = np.array(slice_sizes, dtype=start_indices.dtype)
oob_mask = (start_indices < 0) | (start_indices + slice_sizes > operand_dims)

payload = oob_payload(oob_mask, start_indices, range(operand.ndim), operand.shape)
Expand Down
3 changes: 2 additions & 1 deletion jax/_src/numpy/indexing.py
Original file line number Diff line number Diff line change
Expand Up @@ -606,7 +606,8 @@ def _attempt_rewriting_take_via_slice(arr: Array, idx: Any, mode: str | None,
# We must be careful with dtypes because dynamic_slice requires all
# start indices to have matching types.
if len(start_indices) > 1:
start_indices = util.promote_dtypes(*start_indices)
index_dtype = lax_utils.int_dtype_for_shape(arr.shape, signed=True)
start_indices = [lax.convert_element_type(idx, index_dtype) for idx in start_indices]
jnp_error._check_precondition_oob_dynamic_slice(
arr.shape, start_indices, slice_sizes, allow_negative_indices
)
Expand Down
8 changes: 8 additions & 0 deletions tests/lax_numpy_indexing_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1291,6 +1291,14 @@ def testWrongNumberOfIndices(self):
"Too many indices: 1-dimensional array indexed with 2 regular indices."):
jnp.zeros(3)[:, 5]

@jtu.sample_product(shape=[(), (1,)])
def testIndexDtypePromotion(self, shape):
# Regression test for https://github.com/jax-ml/jax/issues/31396
numbers = jnp.arange(1000)[:, None]
idx = jnp.int8(0).reshape(shape)
expected = np.array(999).reshape(shape)
self.assertArraysEqual(numbers[999, idx], expected)


def _broadcastable_shapes(shape):
"""Returns all shapes that broadcast to `shape`."""
Expand Down
Loading