Skip to content

Commit f6f4ef0

Browse files
committed
Fix indexing corner case with empty ellipses
1 parent 40122f7 commit f6f4ef0

File tree

2 files changed

+16
-5
lines changed

2 files changed

+16
-5
lines changed

jax/_src/numpy/lax_numpy.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -11971,6 +11971,14 @@ def _int(aval):
1197111971

1197211972
def _index_to_gather(x_shape: Sequence[int], idx: Sequence[Any],
1197311973
normalize_indices: bool = True) -> _Indexer:
11974+
# Check whether advanced indices are contiguous. We must do this before
11975+
# removing ellipses (https://github.com/jax-ml/jax/issues/25109)
11976+
# If advanced idexing axes do not appear contiguously, NumPy semantics
11977+
# move the advanced axes to the front.
11978+
is_advanced, = np.nonzero([isinstance(e, (int, Sequence, Array, np.ndarray))
11979+
or isscalar(e) for e in idx])
11980+
advanced_axes_are_contiguous = np.all(np.diff(is_advanced) == 1)
11981+
1197411982
# Remove ellipses and add trailing slice(None)s.
1197511983
idx = _canonicalize_tuple_index(len(x_shape), idx)
1197611984

@@ -11987,10 +11995,6 @@ def _index_to_gather(x_shape: Sequence[int], idx: Sequence[Any],
1198711995
# Check for advanced indexing:
1198811996
# https://docs.scipy.org/doc/numpy/reference/arrays.indexing.html#advanced-indexing
1198911997

11990-
# Do the advanced indexing axes appear contiguously? If not, NumPy semantics
11991-
# move the advanced axes to the front.
11992-
advanced_axes_are_contiguous = False
11993-
1199411998
advanced_indexes: Sequence[Array | np.ndarray] | None = None
1199511999

1199612000
# The positions of the advanced indexing axes in `idx`.
@@ -12009,7 +12013,6 @@ def _index_to_gather(x_shape: Sequence[int], idx: Sequence[Any],
1200912013
advanced_pairs = ((_normalize_index(e, x_shape[j]), i, j)
1201012014
for e, i, j in advanced_pairs)
1201112015
advanced_indexes, idx_advanced_axes, x_advanced_axes = zip(*advanced_pairs)
12012-
advanced_axes_are_contiguous = bool(np.all(np.diff(idx_advanced_axes) == 1))
1201312016

1201412017
x_axis = 0 # Current axis in x.
1201512018
y_axis = 0 # Current axis in y, before collapsing. See below.

tests/lax_numpy_indexing_test.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -399,6 +399,14 @@ def check_grads(f, args, order, atol=None, rtol=None, eps=None):
399399
IndexSpec(shape=(3, 4), indexer=(Ellipsis, np.array(1, dtype=np.int32)),
400400
out_shape=(3,)),
401401
]),
402+
("EllipsisWithArrayIndices", [
403+
IndexSpec(shape=(3, 4, 5), indexer=(np.array([0, 1]), ..., np.array([0, 1])),
404+
out_shape=(2, 4)),
405+
IndexSpec(shape=(3, 4, 5), indexer=(slice(None), np.array([0, 1]), ..., np.array([0, 1])),
406+
out_shape=(2, 3)),
407+
IndexSpec(shape=(3, 4, 5), indexer=(slice(None), ..., np.array([0, 1]), np.array([0, 1])),
408+
out_shape=(3, 2)),
409+
]),
402410
]
403411

404412

0 commit comments

Comments
 (0)