Skip to content

Commit 5094008

Browse files
Fix Ref indexing discrepancies with array indexing
This addresses the issues in #33322, where Ref indexing behaved differently from JAX array indexing: (1) OOB slice clamping: allow empty slices when start equals dim instead of raising an error; (2) None indexing: add a RefNewAxis transform to handle np.newaxis in indices, enabling x[None] and x[..., None]; (3) negative slice steps: convert negative-step slices to positive equivalents and apply a RefFlip transform to reverse the result; (4) update discharge.py to handle the new transforms in both transform_array and transform_swap_array; (5) refactor RefIndexer.__getitem__ to reduce code duplication. Fixes #33322.
1 parent 00fe425 commit 5094008

File tree

3 files changed

+214
-14
lines changed

3 files changed

+214
-14
lines changed

jax/_src/state/discharge.py

Lines changed: 27 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -40,8 +40,8 @@
4040
from jax._src.state import indexing
4141
from jax._src.state.primitives import addupdate_p, get_p, swap_p, pin, unpin
4242
from jax._src.state.types import (
43-
AbstractRef, RefBitcaster, RefEffect, RefReshaper, get_ref_aval_from_value,
44-
uninitialized,)
43+
AbstractRef, RefBitcaster, RefEffect, RefFlip, RefNewAxis, RefReshaper,
44+
get_ref_aval_from_value, uninitialized,)
4545
from jax._src.state.utils import bitcast, hoist_consts_to_refs
4646
from jax._src.typing import Array
4747
from jax._src.util import (foreach, safe_map, safe_zip, split_list, unzip2,
@@ -466,6 +466,11 @@ def transform_array(x, transforms):
466466
result = bitcast(result, transform.dtype)
467467
case RefReshaper():
468468
result = result.reshape(transform.shape)
469+
case RefFlip():
470+
result = lax.rev(result, transform.axes)
471+
case RefNewAxis():
472+
# Insert new axes at specified positions
473+
result = lax.expand_dims(result, sorted(transform.positions))
469474
case _:
470475
raise NotImplementedError(f"Unsupported transform: {transform}")
471476
return result
@@ -509,9 +514,17 @@ def transform_swap_array(x, transforms, val):
509514
# was indexed into.
510515
intermediates.append(new_val)
511516
case RefBitcaster():
512-
intermediates.append(bitcast(new_val, transform.dtype))
517+
new_val = bitcast(new_val, transform.dtype)
518+
intermediates.append(new_val)
513519
case RefReshaper():
514-
intermediates.append(new_val.reshape(transform.shape))
520+
new_val = new_val.reshape(transform.shape)
521+
intermediates.append(new_val)
522+
case RefFlip():
523+
new_val = lax.rev(new_val, transform.axes)
524+
intermediates.append(new_val)
525+
case RefNewAxis():
526+
new_val = lax.expand_dims(new_val, sorted(transform.positions))
527+
intermediates.append(new_val)
515528
case _:
516529
raise NotImplementedError(f"Unsupported transform: {transform}")
517530

@@ -520,7 +533,7 @@ def transform_swap_array(x, transforms, val):
520533
new_x = val
521534

522535
# Write phase (reversed loop)
523-
for intermediate, transform in reversed(zip(intermediates[:-1], transforms)):
536+
for intermediate, transform in reversed(list(zip(intermediates[:-1], transforms))):
524537
if isinstance(transform, indexing.NDIndexer):
525538
indexer = transform
526539
if _is_trivial_indexer(indexer):
@@ -541,6 +554,15 @@ def transform_swap_array(x, transforms, val):
541554
if transpose_order is not None:
542555
transpose_order_inversed = np.argsort(transpose_order)
543556
new_x = new_x.transpose(transpose_order_inversed)
557+
elif isinstance(transform, RefFlip):
558+
# Reverse the flip
559+
new_x = lax.rev(new_x, transform.axes)
560+
elif isinstance(transform, RefNewAxis):
561+
# Squeeze the added axes
562+
new_x = lax.squeeze(new_x, sorted(transform.positions))
563+
elif isinstance(transform, (RefBitcaster, RefReshaper)):
564+
# These are handled implicitly through shape matching
565+
pass
544566
else:
545567
raise NotImplementedError(f"Unsupported transform: {transform}")
546568

jax/_src/state/indexing.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,10 @@ def tree_unflatten(cls, aux_data, children) -> Slice:
7575
def from_slice(cls, slc: slice, size: int) -> Slice:
7676
start, step, size = core.canonicalize_slice(slc, size)
7777
if step < 1:
78-
raise ValueError(f"slice must have a step >= 1 (found: {step})")
78+
raise ValueError(
79+
f"Slice step must be positive (found: {step}). "
80+
"Negative steps should be handled before reaching NDIndexer."
81+
)
7982
return cls(start, size, step)
8083

8184

@@ -172,12 +175,14 @@ def __post_init__(self):
172175
if isinstance(idx, Slice):
173176
start = idx.start
174177
if value := _maybe_concretize(start):
175-
if value >= s:
178+
size_val = _maybe_concretize(idx.size)
179+
# Allow start == s when size == 0 (empty slice)
180+
if value >= s and (size_val is None or size_val > 0):
176181
raise ValueError(f"Out of bound slice: start={value}, dim={s}.")
177-
if size := _maybe_concretize(idx.size):
178-
if value + (size - 1) * idx.stride >= s:
182+
if size_val:
183+
if value + (size_val - 1) * idx.stride >= s:
179184
raise ValueError(
180-
f"Out of bound slice: start={value}, size={size},"
185+
f"Out of bound slice: start={value}, size={size_val},"
181186
f" stride={idx.stride}, dim={s}."
182187
)
183188
continue

jax/_src/state/types.py

Lines changed: 177 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -271,18 +271,191 @@ def pretty_print(self, context: core.JaxprPpContext) -> pp.Doc:
271271
return pp.text(f"{{{self}}}")
272272

273273

274+
@tree_util.register_pytree_node_class
275+
@dataclasses.dataclass(frozen=True)
276+
class RefNewAxis:
277+
"""Transform that inserts new axes at specified positions."""
278+
positions: tuple[int, ...] # positions to insert new axes (in output)
279+
280+
def tree_flatten(self):
281+
return (), (self.positions,)
282+
283+
@classmethod
284+
def tree_unflatten(cls, metadata, arrays):
285+
assert not arrays
286+
return cls(*metadata)
287+
288+
def transform_shape(
289+
self, shape: tuple[int | Array, ...] | None
290+
) -> tuple[int | Array, ...] | None:
291+
if shape is None:
292+
return None
293+
result = list(shape)
294+
for pos in sorted(self.positions):
295+
result.insert(pos, 1)
296+
return tuple(result)
297+
298+
def transform_dtype(self, dtype):
299+
return dtype
300+
301+
def transform_sharding(self, sharding):
302+
if all(p is None for p in sharding.spec):
303+
return sharding
304+
raise NotImplementedError
305+
306+
def pretty_print(self, context: core.JaxprPpContext) -> pp.Doc:
307+
del context
308+
return pp.text(f"{{newaxis{list(self.positions)}}}")
309+
310+
311+
@tree_util.register_pytree_node_class
312+
@dataclasses.dataclass(frozen=True)
313+
class RefFlip:
314+
"""Transform that flips (reverses) specified axes."""
315+
axes: tuple[int, ...] # axes to flip
316+
317+
def tree_flatten(self):
318+
return (), (self.axes,)
319+
320+
@classmethod
321+
def tree_unflatten(cls, metadata, arrays):
322+
assert not arrays
323+
return cls(*metadata)
324+
325+
def transform_shape(
326+
self, shape: tuple[int | Array, ...] | None
327+
) -> tuple[int | Array, ...] | None:
328+
# Flip doesn't change shape
329+
return shape
330+
331+
def transform_dtype(self, dtype):
332+
return dtype
333+
334+
def transform_sharding(self, sharding):
335+
if all(p is None for p in sharding.spec):
336+
return sharding
337+
raise NotImplementedError
338+
339+
def pretty_print(self, context: core.JaxprPpContext) -> pp.Doc:
340+
del context
341+
return pp.text(f"{{flip{list(self.axes)}}}")
342+
343+
344+
def _expand_ellipsis(indices: list, shape_len: int) -> list:
345+
"""Expand ellipsis in indices to appropriate number of slice(None)."""
346+
num_none = sum(idx is None for idx in indices)
347+
num_ellipsis = sum(idx is ... for idx in indices)
348+
349+
if num_ellipsis > 0:
350+
ip = indices.index(...)
351+
num_real_indices = len(indices) - num_ellipsis - num_none
352+
num_slices_needed = shape_len - num_real_indices
353+
indices[ip:ip+1] = [slice(None)] * max(0, num_slices_needed)
354+
355+
return indices
356+
357+
358+
def _separate_none_indices(indices: list) -> tuple[list, list]:
359+
"""Separate None indices and track output positions.
360+
361+
Returns:
362+
(none_positions, filtered_indices)
363+
"""
364+
none_positions = []
365+
filtered_indices = []
366+
output_pos = 0
367+
368+
for idx in indices:
369+
if idx is None:
370+
none_positions.append(output_pos)
371+
output_pos += 1
372+
else:
373+
filtered_indices.append(idx)
374+
if isinstance(idx, slice) or isinstance(idx, indexing.Slice):
375+
output_pos += 1
376+
elif not isinstance(idx, (int, np.integer)) and hasattr(idx, 'shape') and idx.shape:
377+
output_pos += len(idx.shape)
378+
379+
return none_positions, filtered_indices
380+
381+
382+
def _convert_negative_slices(filtered_indices: list, shape: tuple) -> tuple[list, list]:
383+
"""Convert negative step slices to positive equivalents.
384+
385+
Returns:
386+
(converted_indices, flip_axes)
387+
"""
388+
flip_axes = []
389+
converted_indices = []
390+
output_axis = 0
391+
392+
for i, idx in enumerate(filtered_indices):
393+
if isinstance(idx, slice):
394+
dim_size = shape[i] if i < len(shape) else 1
395+
start, step, size = core.canonicalize_slice(idx, dim_size)
396+
397+
if step < 0:
398+
if size > 0:
399+
new_start = start + (size - 1) * step
400+
new_step = -step
401+
converted_indices.append(slice(new_start, new_start + size * new_step, new_step))
402+
flip_axes.append(output_axis)
403+
else:
404+
converted_indices.append(slice(start, start, 1))
405+
output_axis += 1
406+
else:
407+
converted_indices.append(idx)
408+
output_axis += 1
409+
elif isinstance(idx, (int, np.integer)):
410+
converted_indices.append(idx)
411+
else:
412+
converted_indices.append(idx)
413+
if hasattr(idx, 'shape') and idx.shape:
414+
output_axis += len(idx.shape)
415+
416+
return converted_indices, flip_axes
417+
418+
274419
@dataclasses.dataclass
class RefIndexer:
  """Builds a `TransformedRef` view from NumPy-style indexing syntax."""
  ref_or_view: Any

  def __getitem__(self, slc):
    """Return a `TransformedRef` for `slc` applied to the wrapped ref or view.

    Ellipsis is expanded, `None` entries are split off into a trailing
    `RefNewAxis`, and negative-step slices are rewritten as positive-step
    slices followed by a `RefFlip`.
    """
    target = self.ref_or_view
    shape = target.shape
    raw = list(slc) if isinstance(slc, tuple) else [slc]

    # Normalize the raw indices step by step.
    expanded = _expand_ellipsis(raw, len(shape))
    none_positions, filtered_indices = _separate_none_indices(expanded)
    converted_indices, flip_axes = _convert_negative_slices(filtered_indices, shape)

    # With no real indices left (e.g. only `None`s), index every dimension
    # with a full slice; for a rank-0 shape this is the empty tuple.
    filtered_tuple = tuple(converted_indices) or (slice(None),) * len(shape)

    # Extend the transform chain of an existing view, or start a new one.
    if isinstance(target, TransformedRef):
      base_ref, transforms = target.ref, list(target.transforms)
    else:
      base_ref, transforms = target, []

    if filtered_tuple:
      transforms.append(
          indexing.NDIndexer.from_indices_shape(filtered_tuple, shape))
    if flip_axes:
      transforms.append(RefFlip(tuple(flip_axes)))
    if none_positions:
      transforms.append(RefNewAxis(tuple(none_positions)))
    return TransformedRef(base_ref, tuple(transforms))
286459

287460

288461
@dataclasses.dataclass(frozen=True)

0 commit comments

Comments
 (0)