Skip to content

Commit b2c45b8

Browse files
committed
Improved errors when indexing with floats
1 parent 70024d2 commit b2c45b8

File tree

2 files changed

+26
-5
lines changed

2 files changed

+26
-5
lines changed

jax/_src/numpy/indexing.py

Lines changed: 22 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -751,13 +751,33 @@ def merge_static_and_dynamic_indices(treedef, static_idx, dynamic_idx):
751751
def _int(aval):
752752
return not aval.shape and dtypes.issubdtype(aval.dtype, np.integer)
753753

754+
def _aval_or_none(x):
755+
try:
756+
return core.get_aval(x)
757+
except:
758+
return None
759+
754760
def index_to_gather(x_shape: Sequence[int], idx: Sequence[Any],
755761
normalize_indices: bool = True) -> _Indexer:
762+
# Convert sequences to arrays
763+
idx = tuple(lax_numpy.asarray(i, dtype=None if i else int)
764+
if isinstance(i, Sequence) else i for i in idx)
765+
abstract_idx = [_aval_or_none(i) for i in idx]
766+
float_indices = [(i, val, aval) for i, (val, aval) in enumerate(zip(idx, abstract_idx))
767+
if aval is not None and dtypes.issubdtype(aval, np.inexact)]
768+
769+
# Check for float or complex indices:
770+
if float_indices:
771+
i, val, aval = float_indices[0]
772+
msg = ("Indexer must have integer or boolean type, got indexer "
773+
"with type {} at position {}, indexer value {}")
774+
raise TypeError(msg.format(aval.dtype.name, i, val))
775+
756776
# Check whether advanced indices are contiguous. We must do this before
757777
# removing ellipses (https://github.com/jax-ml/jax/issues/25109)
758778
# If advanced idexing axes do not appear contiguously, NumPy semantics
759779
# move the advanced axes to the front.
760-
is_advanced, = np.nonzero([isinstance(e, (int, Sequence, Array, np.ndarray))
780+
is_advanced, = np.nonzero([isinstance(e, (int, np.integer, Array, np.ndarray))
761781
or lax_numpy.isscalar(e) for e in idx])
762782
advanced_axes_are_contiguous = np.all(np.diff(is_advanced) == 1)
763783

@@ -862,11 +882,8 @@ def index_to_gather(x_shape: Sequence[int], idx: Sequence[Any],
862882
gather_slice_shape.append(1)
863883
continue
864884

865-
try:
866-
abstract_i = core.get_aval(i)
867-
except TypeError:
868-
abstract_i = None
869885
# Handle basic int indexes.
886+
abstract_i = _aval_or_none(i)
870887
if isinstance(abstract_i, core.ShapedArray) and _int(abstract_i):
871888
if core.definitely_equal(x_shape[x_axis], 0):
872889
# XLA gives error when indexing into an axis of size 0

tests/lax_numpy_indexing_test.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1120,6 +1120,10 @@ def testFloatIndexingError(self):
11201120
jnp.zeros(2).at[0.].add(1.)
11211121
with self.assertRaisesRegex(TypeError, BAD_INDEX_TYPE_ERROR):
11221122
jnp.zeros(2).at[0.].set(1.)
1123+
with self.assertRaisesRegex(TypeError, BAD_INDEX_TYPE_ERROR):
1124+
jnp.zeros((2, 2))[jnp.arange(2), 1.0]
1125+
with self.assertRaisesRegex(TypeError, BAD_INDEX_TYPE_ERROR):
1126+
jnp.zeros((2, 2))[jnp.arange(2), 1 + 1j]
11231127

11241128
def testStrIndexingError(self):
11251129
msg = "JAX does not support string indexing"

0 commit comments

Comments
 (0)