Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 27 additions & 5 deletions jax/_src/state/discharge.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,8 @@
from jax._src.state import indexing
from jax._src.state.primitives import addupdate_p, get_p, swap_p, pin, unpin
from jax._src.state.types import (
AbstractRef, RefBitcaster, RefEffect, RefReshaper, get_ref_aval_from_value,
uninitialized,)
AbstractRef, RefBitcaster, RefEffect, RefFlip, RefNewAxis, RefReshaper,
get_ref_aval_from_value, uninitialized,)
from jax._src.state.utils import bitcast, hoist_consts_to_refs
from jax._src.typing import Array
from jax._src.util import (foreach, safe_map, safe_zip, split_list, unzip2,
Expand Down Expand Up @@ -466,6 +466,11 @@ def transform_array(x, transforms):
result = bitcast(result, transform.dtype)
case RefReshaper():
result = result.reshape(transform.shape)
case RefFlip():
result = lax.rev(result, transform.axes)
case RefNewAxis():
# Insert new axes at specified positions
result = lax.expand_dims(result, sorted(transform.positions))
case _:
raise NotImplementedError(f"Unsupported transform: {transform}")
return result
Expand Down Expand Up @@ -509,9 +514,17 @@ def transform_swap_array(x, transforms, val):
# was indexed into.
intermediates.append(new_val)
case RefBitcaster():
intermediates.append(bitcast(new_val, transform.dtype))
new_val = bitcast(new_val, transform.dtype)
intermediates.append(new_val)
case RefReshaper():
intermediates.append(new_val.reshape(transform.shape))
new_val = new_val.reshape(transform.shape)
intermediates.append(new_val)
case RefFlip():
new_val = lax.rev(new_val, transform.axes)
intermediates.append(new_val)
case RefNewAxis():
new_val = lax.expand_dims(new_val, sorted(transform.positions))
intermediates.append(new_val)
case _:
raise NotImplementedError(f"Unsupported transform: {transform}")

Expand All @@ -520,7 +533,7 @@ def transform_swap_array(x, transforms, val):
new_x = val

# Write phase (reversed loop)
for intermediate, transform in reversed(zip(intermediates[:-1], transforms)):
for intermediate, transform in reversed(list(zip(intermediates[:-1], transforms))):
if isinstance(transform, indexing.NDIndexer):
indexer = transform
if _is_trivial_indexer(indexer):
Expand All @@ -541,6 +554,15 @@ def transform_swap_array(x, transforms, val):
if transpose_order is not None:
transpose_order_inversed = np.argsort(transpose_order)
new_x = new_x.transpose(transpose_order_inversed)
elif isinstance(transform, RefFlip):
# Reverse the flip
new_x = lax.rev(new_x, transform.axes)
elif isinstance(transform, RefNewAxis):
# Squeeze the added axes
new_x = lax.squeeze(new_x, sorted(transform.positions))
elif isinstance(transform, (RefBitcaster, RefReshaper)):
# These are handled implicitly through shape matching
pass
else:
raise NotImplementedError(f"Unsupported transform: {transform}")

Expand Down
15 changes: 10 additions & 5 deletions jax/_src/state/indexing.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,10 @@ def tree_unflatten(cls, aux_data, children) -> Slice:
def from_slice(cls, slc: slice, size: int) -> Slice:
start, step, size = core.canonicalize_slice(slc, size)
if step < 1:
raise ValueError(f"slice must have a step >= 1 (found: {step})")
raise ValueError(
f"Slice step must be positive (found: {step}). "
"Negative steps should be handled before reaching NDIndexer."
)
return cls(start, size, step)


Expand Down Expand Up @@ -172,12 +175,14 @@ def __post_init__(self):
if isinstance(idx, Slice):
start = idx.start
if value := _maybe_concretize(start):
if value >= s:
size_val = _maybe_concretize(idx.size)
# Allow start == s when size == 0 (empty slice)
if value >= s and (size_val is None or size_val > 0):
raise ValueError(f"Out of bound slice: start={value}, dim={s}.")
if size := _maybe_concretize(idx.size):
if value + (size - 1) * idx.stride >= s:
if size_val:
if value + (size_val - 1) * idx.stride >= s:
raise ValueError(
f"Out of bound slice: start={value}, size={size},"
f"Out of bound slice: start={value}, size={size_val},"
f" stride={idx.stride}, dim={s}."
)
continue
Expand Down
181 changes: 177 additions & 4 deletions jax/_src/state/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -271,18 +271,191 @@ def pretty_print(self, context: core.JaxprPpContext) -> pp.Doc:
return pp.text(f"{{{self}}}")


@tree_util.register_pytree_node_class
@dataclasses.dataclass(frozen=True)
class RefNewAxis:
  """Ref transform that adds size-1 axes.

  ``positions`` holds the indices, in the transformed (output) shape, at which
  the new axes appear. It is static metadata: the pytree has no array leaves.
  """
  positions: tuple[int, ...]  # where the size-1 axes go, in output coordinates

  def tree_flatten(self):
    # No children; `positions` travels as auxiliary data.
    aux = (self.positions,)
    return (), aux

  @classmethod
  def tree_unflatten(cls, metadata, arrays):
    assert not arrays
    (positions,) = metadata
    return cls(positions)

  def transform_shape(
      self, shape: tuple[int | Array, ...] | None
  ) -> tuple[int | Array, ...] | None:
    """Return `shape` with a 1 inserted at every position, or None if unknown."""
    if shape is None:
      return None
    dims = [*shape]
    # Ascending order so that each insertion index refers to the output shape.
    for where in sorted(self.positions):
      dims.insert(where, 1)
    return tuple(dims)

  def transform_dtype(self, dtype):
    """Adding axes leaves the dtype untouched."""
    return dtype

  def transform_sharding(self, sharding):
    """Pass through shardings whose spec is all-None; others are unsupported."""
    if any(entry is not None for entry in sharding.spec):
      raise NotImplementedError
    return sharding

  def pretty_print(self, context: core.JaxprPpContext) -> pp.Doc:
    del context
    return pp.text(f"{{newaxis{list(self.positions)}}}")


@tree_util.register_pytree_node_class
@dataclasses.dataclass(frozen=True)
class RefFlip:
  """Ref transform that reverses the listed axes.

  ``axes`` is static metadata: the pytree has no array leaves.
  """
  axes: tuple[int, ...]  # axes that get reversed

  def tree_flatten(self):
    # No children; `axes` travels as auxiliary data.
    aux = (self.axes,)
    return (), aux

  @classmethod
  def tree_unflatten(cls, metadata, arrays):
    assert not arrays
    (axes,) = metadata
    return cls(axes)

  def transform_shape(
      self, shape: tuple[int | Array, ...] | None
  ) -> tuple[int | Array, ...] | None:
    """Reversing axes keeps every extent, so the shape is returned unchanged."""
    return shape

  def transform_dtype(self, dtype):
    """Reversing axes leaves the dtype untouched."""
    return dtype

  def transform_sharding(self, sharding):
    """Pass through shardings whose spec is all-None; others are unsupported."""
    if any(entry is not None for entry in sharding.spec):
      raise NotImplementedError
    return sharding

  def pretty_print(self, context: core.JaxprPpContext) -> pp.Doc:
    del context
    return pp.text(f"{{flip{list(self.axes)}}}")


def _expand_ellipsis(indices: list, shape_len: int) -> list:
"""Expand ellipsis in indices to appropriate number of slice(None)."""
num_none = sum(idx is None for idx in indices)
num_ellipsis = sum(idx is ... for idx in indices)

if num_ellipsis > 0:
ip = indices.index(...)
num_real_indices = len(indices) - num_ellipsis - num_none
num_slices_needed = shape_len - num_real_indices
indices[ip:ip+1] = [slice(None)] * max(0, num_slices_needed)

return indices


def _separate_none_indices(indices: list) -> tuple[list, list]:
"""Separate None indices and track output positions.

Returns:
(none_positions, filtered_indices)
"""
none_positions = []
filtered_indices = []
output_pos = 0

for idx in indices:
if idx is None:
none_positions.append(output_pos)
output_pos += 1
else:
filtered_indices.append(idx)
if isinstance(idx, slice) or isinstance(idx, indexing.Slice):
output_pos += 1
elif not isinstance(idx, (int, np.integer)) and hasattr(idx, 'shape') and idx.shape:
output_pos += len(idx.shape)

return none_positions, filtered_indices


def _convert_negative_slices(filtered_indices: list, shape: tuple) -> tuple[list, list]:
"""Convert negative step slices to positive equivalents.

Returns:
(converted_indices, flip_axes)
"""
flip_axes = []
converted_indices = []
output_axis = 0

for i, idx in enumerate(filtered_indices):
if isinstance(idx, slice):
dim_size = shape[i] if i < len(shape) else 1
start, step, size = core.canonicalize_slice(idx, dim_size)

if step < 0:
if size > 0:
new_start = start + (size - 1) * step
new_step = -step
converted_indices.append(slice(new_start, new_start + size * new_step, new_step))
flip_axes.append(output_axis)
else:
converted_indices.append(slice(start, start, 1))
output_axis += 1
else:
converted_indices.append(idx)
output_axis += 1
elif isinstance(idx, (int, np.integer)):
converted_indices.append(idx)
else:
converted_indices.append(idx)
if hasattr(idx, 'shape') and idx.shape:
output_axis += len(idx.shape)

return converted_indices, flip_axes


@dataclasses.dataclass
class RefIndexer:
  """Indexing helper that turns ``[...]`` on a ref or view into a TransformedRef.

  ``ref_or_view`` is either a plain ref or an existing ``TransformedRef``; in
  the latter case the new transforms are appended to the ones it already has.
  """
  ref_or_view: Any

  def __getitem__(self, slc):
    """Translate an index expression into a chain of ref transforms.

    The index is processed in three steps: an ellipsis is expanded into full
    slices, ``None`` entries are pulled out and remembered as new-axis
    positions, and negative-step slices are rewritten as positive-step slices
    plus a list of axes to flip. The resulting transforms are appended in the
    order: indexer, flip, new axes.
    """
    if not isinstance(slc, tuple):
      slc = (slc,)

    shape = self.ref_or_view.shape

    # Expand ellipsis and process indices.
    indices = _expand_ellipsis(list(slc), len(shape))
    none_positions, filtered_indices = _separate_none_indices(indices)
    converted_indices, flip_axes = _convert_negative_slices(filtered_indices, shape)

    # Indices handed to the NDIndexer. When only `None` entries were given,
    # fall back to selecting everything (nothing at all for a scalar shape).
    if converted_indices:
      filtered_tuple = tuple(converted_indices)
    elif shape == ():
      filtered_tuple = ()
    else:
      filtered_tuple = (slice(None),) * len(shape)

    # Extend the transforms of an existing view, or start a fresh chain.
    if isinstance(self.ref_or_view, TransformedRef):
      base_ref = self.ref_or_view.ref
      current_transforms = self.ref_or_view.transforms
    else:
      base_ref = self.ref_or_view
      current_transforms = ()

    transforms = list(current_transforms)
    if filtered_tuple:
      # Built from the converted indices only: the raw index may still hold
      # `None` or negative-step slices.
      indexer = indexing.NDIndexer.from_indices_shape(filtered_tuple, shape)
      transforms.append(indexer)
    if flip_axes:
      transforms.append(RefFlip(tuple(flip_axes)))
    if none_positions:
      transforms.append(RefNewAxis(tuple(none_positions)))
    return TransformedRef(base_ref, tuple(transforms))


@dataclasses.dataclass(frozen=True)
Expand Down