@@ -11971,6 +11971,14 @@ def _int(aval):
1197111971
1197211972def _index_to_gather (x_shape : Sequence [int ], idx : Sequence [Any ],
1197311973 normalize_indices : bool = True ) -> _Indexer :
11974+ # Check whether advanced indices are contiguous. We must do this before
11975+ # removing ellipses (https://github.com/jax-ml/jax/issues/25109)
11976+ # If advanced idexing axes do not appear contiguously, NumPy semantics
11977+ # move the advanced axes to the front.
11978+ is_advanced , = np .nonzero ([isinstance (e , (int , Sequence , Array , np .ndarray ))
11979+ or isscalar (e ) for e in idx ])
11980+ advanced_axes_are_contiguous = np .all (np .diff (is_advanced ) == 1 )
11981+
1197411982 # Remove ellipses and add trailing slice(None)s.
1197511983 idx = _canonicalize_tuple_index (len (x_shape ), idx )
1197611984
@@ -11987,10 +11995,6 @@ def _index_to_gather(x_shape: Sequence[int], idx: Sequence[Any],
1198711995 # Check for advanced indexing:
1198811996 # https://docs.scipy.org/doc/numpy/reference/arrays.indexing.html#advanced-indexing
1198911997
11990- # Do the advanced indexing axes appear contiguously? If not, NumPy semantics
11991- # move the advanced axes to the front.
11992- advanced_axes_are_contiguous = False
11993-
1199411998 advanced_indexes : Sequence [Array | np .ndarray ] | None = None
1199511999
1199612000 # The positions of the advanced indexing axes in `idx`.
@@ -12009,7 +12013,6 @@ def _index_to_gather(x_shape: Sequence[int], idx: Sequence[Any],
1200912013 advanced_pairs = ((_normalize_index (e , x_shape [j ]), i , j )
1201012014 for e , i , j in advanced_pairs )
1201112015 advanced_indexes , idx_advanced_axes , x_advanced_axes = zip (* advanced_pairs )
12012- advanced_axes_are_contiguous = bool (np .all (np .diff (idx_advanced_axes ) == 1 ))
1201312016
1201412017 x_axis = 0 # Current axis in x.
1201512018 y_axis = 0 # Current axis in y, before collapsing. See below.
0 commit comments