diff --git a/jax/_src/checkify.py b/jax/_src/checkify.py index 13b0323c4d07..11196ca3fc23 100644 --- a/jax/_src/checkify.py +++ b/jax/_src/checkify.py @@ -627,9 +627,9 @@ def dynamic_slice_error_check(error, enabled_errors, operand, *start_indices, sl if OOBError not in enabled_errors: return error, out - operand_dims = np.array(operand.shape) - slice_sizes = np.array(slice_sizes) start_indices = jnp.array(start_indices) + operand_dims = np.array(operand.shape, dtype=start_indices.dtype) + slice_sizes = np.array(slice_sizes, dtype=start_indices.dtype) oob_mask = (start_indices < 0) | (start_indices + slice_sizes > operand_dims) payload = oob_payload(oob_mask, start_indices, range(operand.ndim), operand.shape) diff --git a/jax/_src/numpy/indexing.py b/jax/_src/numpy/indexing.py index aa5f5d3a7626..a4c91cef5211 100644 --- a/jax/_src/numpy/indexing.py +++ b/jax/_src/numpy/indexing.py @@ -606,7 +606,8 @@ def _attempt_rewriting_take_via_slice(arr: Array, idx: Any, mode: str | None, # We must be careful with dtypes because dynamic_slice requires all # start indices to have matching types. if len(start_indices) > 1: - start_indices = util.promote_dtypes(*start_indices) + index_dtype = lax_utils.int_dtype_for_shape(arr.shape, signed=True) + start_indices = [lax.convert_element_type(idx, index_dtype) for idx in start_indices] jnp_error._check_precondition_oob_dynamic_slice( arr.shape, start_indices, slice_sizes, allow_negative_indices ) diff --git a/tests/lax_numpy_indexing_test.py b/tests/lax_numpy_indexing_test.py index a58daa464d84..d9be101ac787 100644 --- a/tests/lax_numpy_indexing_test.py +++ b/tests/lax_numpy_indexing_test.py @@ -1291,6 +1291,14 @@ def testWrongNumberOfIndices(self): "Too many indices: 1-dimensional array indexed with 2 regular indices."): jnp.zeros(3)[:, 5] + @jtu.sample_product(shape=[(), (1,)]) + def testIndexDtypePromotion(self, shape): + # Regression test for https://github.com/jax-ml/jax/issues/31396 + numbers = jnp.arange(1000)[:, None] + idx = jnp.int8(0).reshape(shape) + expected = np.array(999).reshape(shape) + self.assertArraysEqual(numbers[999, idx], expected) + def _broadcastable_shapes(shape): """Returns all shapes that broadcast to `shape`."""