From 5094008263447548b2cba28a9a0712cd316d476a Mon Sep 17 00:00:00 2001 From: Yashwant Bezawada Date: Thu, 20 Nov 2025 21:47:55 -0600 Subject: [PATCH] Fix Ref indexing discrepancies with array indexing This addresses the issues in #33322 where Ref indexing behaved differently from JAX array indexing: 1. OOB slice clamping: Allow empty slices when start equals dim instead of raising an error 2. None indexing: Add RefNewAxis transform to handle np.newaxis in indices, enabling x[None] and x[..., None] 3. Negative slice steps: Convert negative step slices to positive equivalents and apply RefFlip transform to reverse the result 4. Updated discharge.py to handle the new transforms in both transform_array and transform_swap_array 5. Refactored RefIndexer.__getitem__ to reduce code duplication Fixes #33322. --- jax/_src/state/discharge.py | 32 ++++++- jax/_src/state/indexing.py | 15 ++- jax/_src/state/types.py | 181 +++++++++++++++++++++++++++++++++++- 3 files changed, 214 insertions(+), 14 deletions(-) diff --git a/jax/_src/state/discharge.py b/jax/_src/state/discharge.py index a20d5648217b..db926a4bf105 100644 --- a/jax/_src/state/discharge.py +++ b/jax/_src/state/discharge.py @@ -40,8 +40,8 @@ from jax._src.state import indexing from jax._src.state.primitives import addupdate_p, get_p, swap_p, pin, unpin from jax._src.state.types import ( - AbstractRef, RefBitcaster, RefEffect, RefReshaper, get_ref_aval_from_value, - uninitialized,) + AbstractRef, RefBitcaster, RefEffect, RefFlip, RefNewAxis, RefReshaper, + get_ref_aval_from_value, uninitialized,) from jax._src.state.utils import bitcast, hoist_consts_to_refs from jax._src.typing import Array from jax._src.util import (foreach, safe_map, safe_zip, split_list, unzip2, @@ -466,6 +466,11 @@ def transform_array(x, transforms): result = bitcast(result, transform.dtype) case RefReshaper(): result = result.reshape(transform.shape) + case RefFlip(): + result = lax.rev(result, transform.axes) + case RefNewAxis(): + # Insert new axes at specified positions + result = lax.expand_dims(result, sorted(transform.positions)) case _: raise NotImplementedError(f"Unsupported transform: {transform}") return result @@ -509,9 +514,17 @@ def transform_swap_array(x, transforms, val): # was indexed into. intermediates.append(new_val) case RefBitcaster(): - intermediates.append(bitcast(new_val, transform.dtype)) + new_val = bitcast(new_val, transform.dtype) + intermediates.append(new_val) case RefReshaper(): - intermediates.append(new_val.reshape(transform.shape)) + new_val = new_val.reshape(transform.shape) + intermediates.append(new_val) + case RefFlip(): + new_val = lax.rev(new_val, transform.axes) + intermediates.append(new_val) + case RefNewAxis(): + new_val = lax.expand_dims(new_val, sorted(transform.positions)) + intermediates.append(new_val) case _: raise NotImplementedError(f"Unsupported transform: {transform}") @@ -520,7 +533,7 @@ def transform_swap_array(x, transforms, val): new_x = val # Write phase (reversed loop) - for intermediate, transform in reversed(zip(intermediates[:-1], transforms)): + for intermediate, transform in reversed(list(zip(intermediates[:-1], transforms))): if isinstance(transform, indexing.NDIndexer): indexer = transform if _is_trivial_indexer(indexer): @@ -541,6 +554,15 @@ def transform_swap_array(x, transforms, val): if transpose_order is not None: transpose_order_inversed = np.argsort(transpose_order) new_x = new_x.transpose(transpose_order_inversed) + elif isinstance(transform, RefFlip): + # Reverse the flip + new_x = lax.rev(new_x, transform.axes) + elif isinstance(transform, RefNewAxis): + # Squeeze the added axes + new_x = lax.squeeze(new_x, sorted(transform.positions)) + elif isinstance(transform, (RefBitcaster, RefReshaper)): + # These are handled implicitly through shape matching + pass else: raise NotImplementedError(f"Unsupported transform: {transform}") diff --git a/jax/_src/state/indexing.py b/jax/_src/state/indexing.py index c2741c91fa7d..0454ba6d06f2 100644 --- a/jax/_src/state/indexing.py +++ b/jax/_src/state/indexing.py @@ -75,7 +75,10 @@ def tree_unflatten(cls, aux_data, children) -> Slice: def from_slice(cls, slc: slice, size: int) -> Slice: start, step, size = core.canonicalize_slice(slc, size) if step < 1: - raise ValueError(f"slice must have a step >= 1 (found: {step})") + raise ValueError( + f"Slice step must be positive (found: {step}). " + "Negative steps should be handled before reaching NDIndexer." + ) return cls(start, size, step) @@ -172,12 +175,14 @@ def __post_init__(self): if isinstance(idx, Slice): start = idx.start if value := _maybe_concretize(start): - if value >= s: + size_val = _maybe_concretize(idx.size) + # Allow start == s when size == 0 (empty slice) + if value >= s and (size_val is None or size_val > 0): raise ValueError(f"Out of bound slice: start={value}, dim={s}.") - if size := _maybe_concretize(idx.size): - if value + (size - 1) * idx.stride >= s: + if size_val: + if value + (size_val - 1) * idx.stride >= s: raise ValueError( - f"Out of bound slice: start={value}, size={size}," + f"Out of bound slice: start={value}, size={size_val}," f" stride={idx.stride}, dim={s}." ) continue diff --git a/jax/_src/state/types.py b/jax/_src/state/types.py index 2644f8392416..44a2da779206 100644 --- a/jax/_src/state/types.py +++ b/jax/_src/state/types.py @@ -271,6 +271,151 @@ def pretty_print(self, context: core.JaxprPpContext) -> pp.Doc: return pp.text(f"{{{self}}}") +@tree_util.register_pytree_node_class +@dataclasses.dataclass(frozen=True) +class RefNewAxis: + """Transform that inserts new axes at specified positions.""" + positions: tuple[int, ...] # positions to insert new axes (in output) + + def tree_flatten(self): + return (), (self.positions,) + + @classmethod + def tree_unflatten(cls, metadata, arrays): + assert not arrays + return cls(*metadata) + + def transform_shape( + self, shape: tuple[int | Array, ...] | None + ) -> tuple[int | Array, ...] | None: + if shape is None: + return None + result = list(shape) + for pos in sorted(self.positions): + result.insert(pos, 1) + return tuple(result) + + def transform_dtype(self, dtype): + return dtype + + def transform_sharding(self, sharding): + if all(p is None for p in sharding.spec): + return sharding + raise NotImplementedError + + def pretty_print(self, context: core.JaxprPpContext) -> pp.Doc: + del context + return pp.text(f"{{newaxis{list(self.positions)}}}") + + +@tree_util.register_pytree_node_class +@dataclasses.dataclass(frozen=True) +class RefFlip: + """Transform that flips (reverses) specified axes.""" + axes: tuple[int, ...] # axes to flip + + def tree_flatten(self): + return (), (self.axes,) + + @classmethod + def tree_unflatten(cls, metadata, arrays): + assert not arrays + return cls(*metadata) + + def transform_shape( + self, shape: tuple[int | Array, ...] | None + ) -> tuple[int | Array, ...] | None: + # Flip doesn't change shape + return shape + + def transform_dtype(self, dtype): + return dtype + + def transform_sharding(self, sharding): + if all(p is None for p in sharding.spec): + return sharding + raise NotImplementedError + + def pretty_print(self, context: core.JaxprPpContext) -> pp.Doc: + del context + return pp.text(f"{{flip{list(self.axes)}}}") + + +def _expand_ellipsis(indices: list, shape_len: int) -> list: + """Expand ellipsis in indices to appropriate number of slice(None).""" + num_none = sum(idx is None for idx in indices) + num_ellipsis = sum(idx is ... for idx in indices) + + if num_ellipsis > 0: + ip = indices.index(...) + num_real_indices = len(indices) - num_ellipsis - num_none + num_slices_needed = shape_len - num_real_indices + indices[ip:ip+1] = [slice(None)] * max(0, num_slices_needed) + + return indices + + +def _separate_none_indices(indices: list) -> tuple[list, list]: + """Separate None indices and track output positions. + + Returns: + (none_positions, filtered_indices) + """ + none_positions = [] + filtered_indices = [] + output_pos = 0 + + for idx in indices: + if idx is None: + none_positions.append(output_pos) + output_pos += 1 + else: + filtered_indices.append(idx) + if isinstance(idx, slice) or isinstance(idx, indexing.Slice): + output_pos += 1 + elif not isinstance(idx, (int, np.integer)) and hasattr(idx, 'shape') and idx.shape: + output_pos += len(idx.shape) + + return none_positions, filtered_indices + + +def _convert_negative_slices(filtered_indices: list, shape: tuple) -> tuple[list, list]: + """Convert negative step slices to positive equivalents. + + Returns: + (converted_indices, flip_axes) + """ + flip_axes = [] + converted_indices = [] + output_axis = 0 + + for i, idx in enumerate(filtered_indices): + if isinstance(idx, slice): + dim_size = shape[i] if i < len(shape) else 1 + start, step, size = core.canonicalize_slice(idx, dim_size) + + if step < 0: + if size > 0: + new_start = start + (size - 1) * step + new_step = -step + converted_indices.append(slice(new_start, new_start + size * new_step, new_step)) + flip_axes.append(output_axis) + else: + converted_indices.append(slice(start, start, 1)) + output_axis += 1 + else: + converted_indices.append(idx) + output_axis += 1 + elif isinstance(idx, (int, np.integer)): + converted_indices.append(idx) + else: + converted_indices.append(idx) + if hasattr(idx, 'shape') and idx.shape: + output_axis += len(idx.shape) + + return converted_indices, flip_axes + + @dataclasses.dataclass class RefIndexer: ref_or_view: Any @@ -278,11 +423,39 @@ class RefIndexer: def __getitem__(self, slc): if not isinstance(slc, tuple): slc = (slc,) - indexer = indexing.NDIndexer.from_indices_shape(slc, self.ref_or_view.shape) + + shape = self.ref_or_view.shape + + # Expand ellipsis and process indices + indices = _expand_ellipsis(list(slc), len(shape)) + none_positions, filtered_indices = _separate_none_indices(indices) + converted_indices, flip_axes = _convert_negative_slices(filtered_indices, shape) + + # Create indexer tuple + if converted_indices: + filtered_tuple = tuple(converted_indices) + elif shape == (): + filtered_tuple = () + else: + filtered_tuple = (slice(None),) * len(shape) + + # Build the result if isinstance(self.ref_or_view, TransformedRef): - view = self.ref_or_view - return TransformedRef(view.ref, (*view.transforms, indexer)) - return TransformedRef(self.ref_or_view, (indexer,)) + base_ref = self.ref_or_view.ref + current_transforms = self.ref_or_view.transforms + else: + base_ref = self.ref_or_view + current_transforms = () + + transforms = list(current_transforms) + if filtered_tuple: + indexer = indexing.NDIndexer.from_indices_shape(filtered_tuple, shape) + transforms.append(indexer) + if flip_axes: + transforms.append(RefFlip(tuple(flip_axes))) + if none_positions: + transforms.append(RefNewAxis(tuple(none_positions))) + return TransformedRef(base_ref, tuple(transforms)) @dataclasses.dataclass(frozen=True)