Skip to content

Commit 0fc6707

Browse files
Fix Ref indexing discrepancies with array indexing
This commit addresses issues where Ref indexing behaved differently from JAX array indexing:
1. Out-of-bounds (OOB) slice clamping: allow start == dim when size == 0 (an empty slice) instead of raising an error.
2. None indexing support: add a RefNewAxis transform to handle np.newaxis (None) in indices, enabling patterns like x[None] and x[..., None].
3. Improved error message for negative slice steps, clarifying that this is a known limitation of Ref indexing.
Fixes parts of issue #33322.
1 parent 00fe425 commit 0fc6707

File tree

2 files changed

+105
-8
lines changed

2 files changed

+105
-8
lines changed

jax/_src/state/indexing.py

Lines changed: 10 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -75,7 +75,10 @@ def tree_unflatten(cls, aux_data, children) -> Slice:
7575
def from_slice(cls, slc: slice, size: int) -> Slice:
  """Build an instance from a Python ``slice`` applied to an axis of ``size``.

  The slice is canonicalized with ``core.canonicalize_slice``, which yields a
  ``(start, step, size)`` triple; the instance is then constructed as
  ``cls(start, size, step)``.

  Args:
    slc: the Python slice to convert.
    size: the length of the axis being sliced.

  Returns:
    ``cls(start, length, step)`` built from the canonicalized slice.

  Raises:
    ValueError: if the canonicalized step is smaller than 1.
  """
  start, step, length = core.canonicalize_slice(slc, size)
  if step >= 1:
    return cls(start, length, step)
  # Only strictly positive strides can be represented here.
  raise ValueError(
      f"Ref indexing does not support negative slice steps (found: {step}). "
      "Consider using positive steps and reversing the result if needed."
  )
8083

8184

@@ -172,12 +175,14 @@ def __post_init__(self):
172175
if isinstance(idx, Slice):
173176
start = idx.start
174177
if value := _maybe_concretize(start):
175-
if value >= s:
178+
size_val = _maybe_concretize(idx.size)
179+
# Allow start == s when size == 0 (empty slice)
180+
if value >= s and (size_val is None or size_val > 0):
176181
raise ValueError(f"Out of bound slice: start={value}, dim={s}.")
177-
if size := _maybe_concretize(idx.size):
178-
if value + (size - 1) * idx.stride >= s:
182+
if size_val:
183+
if value + (size_val - 1) * idx.stride >= s:
179184
raise ValueError(
180-
f"Out of bound slice: start={value}, size={size},"
185+
f"Out of bound slice: start={value}, size={size_val},"
181186
f" stride={idx.stride}, dim={s}."
182187
)
183188
continue

jax/_src/state/types.py

Lines changed: 95 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -271,18 +271,110 @@ def pretty_print(self, context: core.JaxprPpContext) -> pp.Doc:
271271
return pp.text(f"{{{self}}}")
272272

273273

274+
@tree_util.register_pytree_node_class
275+
@dataclasses.dataclass(frozen=True)
276+
class RefNewAxis:
277+
"""Transform that inserts new axes at specified positions."""
278+
positions: tuple[int, ...] # positions to insert new axes (in output)
279+
280+
def tree_flatten(self):
281+
return (), (self.positions,)
282+
283+
@classmethod
284+
def tree_unflatten(cls, metadata, arrays):
285+
assert not arrays
286+
return cls(*metadata)
287+
288+
def transform_shape(
289+
self, shape: tuple[int | Array, ...] | None
290+
) -> tuple[int | Array, ...] | None:
291+
if shape is None:
292+
return None
293+
result = list(shape)
294+
for pos in sorted(self.positions):
295+
result.insert(pos, 1)
296+
return tuple(result)
297+
298+
def transform_dtype(self, dtype):
299+
return dtype
300+
301+
def transform_sharding(self, sharding):
302+
if all(p is None for p in sharding.spec):
303+
return sharding
304+
raise NotImplementedError
305+
306+
def pretty_print(self, context: core.JaxprPpContext) -> pp.Doc:
307+
del context
308+
return pp.text(f"{{newaxis{list(self.positions)}}}")
309+
310+
274311
@dataclasses.dataclass
class RefIndexer:
  """Helper whose ``[...]`` operator builds a ``TransformedRef``.

  ``ref_or_view`` is either a plain ref or an existing ``TransformedRef``.
  Indexing appends an ``NDIndexer`` transform (and, when the index contains
  ``None``/``np.newaxis``, a trailing ``RefNewAxis`` transform) to whatever
  transforms the view already carries.
  """
  ref_or_view: Any

  def __getitem__(self, slc):
    """Translate an index expression into transforms on the underlying ref.

    Args:
      slc: a single index or a tuple of indices. Entries may be ``None``
        (new axis), ``...``, Python slices, ``indexing.Slice`` objects,
        integers, or array-like indexers.

    Returns:
      A ``TransformedRef`` wrapping the underlying ref with the existing
      transforms followed by the newly created ones.

    Raises:
      ValueError: if more than one ``...`` appears in the index.
    """
    if not isinstance(slc, tuple):
      slc = (slc,)
    indices = list(slc)
    shape = self.ref_or_view.shape

    # Locate ellipses by identity. ``list.index(...)`` compares with ``==``,
    # which can raise for array-valued indexers that precede the ellipsis.
    ellipsis_positions = [i for i, idx in enumerate(indices) if idx is ...]
    if len(ellipsis_positions) > 1:
      # Expanding only the first ellipsis would pass the others on and
      # silently accept an index that array indexing rejects.
      raise ValueError("Only one ellipsis is supported.")
    if ellipsis_positions:
      (ip,) = ellipsis_positions
      # ``None`` entries do not consume an input dimension, so they are not
      # counted when working out how many dimensions the ellipsis stands for.
      num_none = sum(idx is None for idx in indices)
      num_real_indices = len(indices) - 1 - num_none
      num_slices_needed = len(shape) - num_real_indices
      indices[ip:ip + 1] = [slice(None)] * max(0, num_slices_needed)

    # Split ``None`` entries from real indices, recording for each ``None``
    # the position it occupies in the output shape.
    none_positions = []
    filtered_indices = []
    output_pos = 0
    for idx in indices:
      if idx is None:
        none_positions.append(output_pos)
        output_pos += 1
        continue
      filtered_indices.append(idx)
      if isinstance(idx, (slice, indexing.Slice)):
        # A slice keeps its dimension in the output.
        output_pos += 1
      elif (not isinstance(idx, (int, np.integer))
            and hasattr(idx, 'shape') and idx.shape):
        # NOTE(review): each non-scalar array indexer is assumed to add its
        # own rank to the output; confirm this matches how NDIndexer lays
        # out several array indexers in one index expression.
        output_pos += len(idx.shape)
      # Scalar integer indexers drop their dimension: no output position.

    # With only ``None`` entries (or an empty index), index the whole ref.
    # For a 0-d ref this yields an empty tuple and no indexer is created.
    if not filtered_indices:
      filtered_indices = [slice(None)] * len(shape)

    if isinstance(self.ref_or_view, TransformedRef):
      ref = self.ref_or_view.ref
      transforms = self.ref_or_view.transforms
    else:
      ref = self.ref_or_view
      transforms = ()

    if filtered_indices:
      indexer = indexing.NDIndexer.from_indices_shape(
          tuple(filtered_indices), shape)
      transforms = (*transforms, indexer)
    if none_positions:
      transforms = (*transforms, RefNewAxis(tuple(none_positions)))
    return TransformedRef(ref, transforms)
286378

287379

288380
@dataclasses.dataclass(frozen=True)

0 commit comments

Comments
 (0)