@@ -751,13 +751,33 @@ def merge_static_and_dynamic_indices(treedef, static_idx, dynamic_idx):
751751def _int (aval ):
752752 return not aval .shape and dtypes .issubdtype (aval .dtype , np .integer )
753753
754+ def _aval_or_none (x ):
755+ try :
756+ return core .get_aval (x )
757+ except :
758+ return None
759+
754760def index_to_gather (x_shape : Sequence [int ], idx : Sequence [Any ],
755761 normalize_indices : bool = True ) -> _Indexer :
762+ # Convert sequences to arrays
763+ idx = tuple (lax_numpy .asarray (i , dtype = None if i else int )
764+ if isinstance (i , Sequence ) else i for i in idx )
765+ abstract_idx = [_aval_or_none (i ) for i in idx ]
766+ float_indices = [(i , val , aval ) for i , (val , aval ) in enumerate (zip (idx , abstract_idx ))
767+ if aval is not None and dtypes .issubdtype (aval , np .inexact )]
768+
769+ # Check for float or complex indices:
770+ if float_indices :
771+ i , val , aval = float_indices [0 ]
772+ msg = ("Indexer must have integer or boolean type, got indexer "
773+ "with type {} at position {}, indexer value {}" )
774+ raise TypeError (msg .format (aval .dtype .name , i , val ))
775+
756776 # Check whether advanced indices are contiguous. We must do this before
757777 # removing ellipses (https://github.com/jax-ml/jax/issues/25109)
758778 # If advanced idexing axes do not appear contiguously, NumPy semantics
759779 # move the advanced axes to the front.
760- is_advanced , = np .nonzero ([isinstance (e , (int , Sequence , Array , np .ndarray ))
780+ is_advanced , = np .nonzero ([isinstance (e , (int , np . integer , Array , np .ndarray ))
761781 or lax_numpy .isscalar (e ) for e in idx ])
762782 advanced_axes_are_contiguous = np .all (np .diff (is_advanced ) == 1 )
763783
@@ -862,11 +882,8 @@ def index_to_gather(x_shape: Sequence[int], idx: Sequence[Any],
862882 gather_slice_shape .append (1 )
863883 continue
864884
865- try :
866- abstract_i = core .get_aval (i )
867- except TypeError :
868- abstract_i = None
869885 # Handle basic int indexes.
886+ abstract_i = _aval_or_none (i )
870887 if isinstance (abstract_i , core .ShapedArray ) and _int (abstract_i ):
871888 if core .definitely_equal (x_shape [x_axis ], 0 ):
872889 # XLA gives error when indexing into an axis of size 0
0 commit comments