@@ -271,18 +271,191 @@ def pretty_print(self, context: core.JaxprPpContext) -> pp.Doc:
271271 return pp .text (f"{{{ self } }}" )
272272
273273
274+ @tree_util .register_pytree_node_class
275+ @dataclasses .dataclass (frozen = True )
276+ class RefNewAxis :
277+ """Transform that inserts new axes at specified positions."""
278+ positions : tuple [int , ...] # positions to insert new axes (in output)
279+
280+ def tree_flatten (self ):
281+ return (), (self .positions ,)
282+
283+ @classmethod
284+ def tree_unflatten (cls , metadata , arrays ):
285+ assert not arrays
286+ return cls (* metadata )
287+
288+ def transform_shape (
289+ self , shape : tuple [int | Array , ...] | None
290+ ) -> tuple [int | Array , ...] | None :
291+ if shape is None :
292+ return None
293+ result = list (shape )
294+ for pos in sorted (self .positions ):
295+ result .insert (pos , 1 )
296+ return tuple (result )
297+
298+ def transform_dtype (self , dtype ):
299+ return dtype
300+
301+ def transform_sharding (self , sharding ):
302+ if all (p is None for p in sharding .spec ):
303+ return sharding
304+ raise NotImplementedError
305+
306+ def pretty_print (self , context : core .JaxprPpContext ) -> pp .Doc :
307+ del context
308+ return pp .text (f"{{newaxis{ list (self .positions )} }}" )
309+
310+
311+ @tree_util .register_pytree_node_class
312+ @dataclasses .dataclass (frozen = True )
313+ class RefFlip :
314+ """Transform that flips (reverses) specified axes."""
315+ axes : tuple [int , ...] # axes to flip
316+
317+ def tree_flatten (self ):
318+ return (), (self .axes ,)
319+
320+ @classmethod
321+ def tree_unflatten (cls , metadata , arrays ):
322+ assert not arrays
323+ return cls (* metadata )
324+
325+ def transform_shape (
326+ self , shape : tuple [int | Array , ...] | None
327+ ) -> tuple [int | Array , ...] | None :
328+ # Flip doesn't change shape
329+ return shape
330+
331+ def transform_dtype (self , dtype ):
332+ return dtype
333+
334+ def transform_sharding (self , sharding ):
335+ if all (p is None for p in sharding .spec ):
336+ return sharding
337+ raise NotImplementedError
338+
339+ def pretty_print (self , context : core .JaxprPpContext ) -> pp .Doc :
340+ del context
341+ return pp .text (f"{{flip{ list (self .axes )} }}" )
342+
343+
344+ def _expand_ellipsis (indices : list , shape_len : int ) -> list :
345+ """Expand ellipsis in indices to appropriate number of slice(None)."""
346+ num_none = sum (idx is None for idx in indices )
347+ num_ellipsis = sum (idx is ... for idx in indices )
348+
349+ if num_ellipsis > 0 :
350+ ip = indices .index (...)
351+ num_real_indices = len (indices ) - num_ellipsis - num_none
352+ num_slices_needed = shape_len - num_real_indices
353+ indices [ip :ip + 1 ] = [slice (None )] * max (0 , num_slices_needed )
354+
355+ return indices
356+
357+
358+ def _separate_none_indices (indices : list ) -> tuple [list , list ]:
359+ """Separate None indices and track output positions.
360+
361+ Returns:
362+ (none_positions, filtered_indices)
363+ """
364+ none_positions = []
365+ filtered_indices = []
366+ output_pos = 0
367+
368+ for idx in indices :
369+ if idx is None :
370+ none_positions .append (output_pos )
371+ output_pos += 1
372+ else :
373+ filtered_indices .append (idx )
374+ if isinstance (idx , slice ) or isinstance (idx , indexing .Slice ):
375+ output_pos += 1
376+ elif not isinstance (idx , (int , np .integer )) and hasattr (idx , 'shape' ) and idx .shape :
377+ output_pos += len (idx .shape )
378+
379+ return none_positions , filtered_indices
380+
381+
382+ def _convert_negative_slices (filtered_indices : list , shape : tuple ) -> tuple [list , list ]:
383+ """Convert negative step slices to positive equivalents.
384+
385+ Returns:
386+ (converted_indices, flip_axes)
387+ """
388+ flip_axes = []
389+ converted_indices = []
390+ output_axis = 0
391+
392+ for i , idx in enumerate (filtered_indices ):
393+ if isinstance (idx , slice ):
394+ dim_size = shape [i ] if i < len (shape ) else 1
395+ start , step , size = core .canonicalize_slice (idx , dim_size )
396+
397+ if step < 0 :
398+ if size > 0 :
399+ new_start = start + (size - 1 ) * step
400+ new_step = - step
401+ converted_indices .append (slice (new_start , new_start + size * new_step , new_step ))
402+ flip_axes .append (output_axis )
403+ else :
404+ converted_indices .append (slice (start , start , 1 ))
405+ output_axis += 1
406+ else :
407+ converted_indices .append (idx )
408+ output_axis += 1
409+ elif isinstance (idx , (int , np .integer )):
410+ converted_indices .append (idx )
411+ else :
412+ converted_indices .append (idx )
413+ if hasattr (idx , 'shape' ) and idx .shape :
414+ output_axis += len (idx .shape )
415+
416+ return converted_indices , flip_axes
417+
418+
274419@dataclasses .dataclass
275420class RefIndexer :
276421 ref_or_view : Any
277422
278423 def __getitem__ (self , slc ):
279424 if not isinstance (slc , tuple ):
280425 slc = (slc ,)
281- indexer = indexing .NDIndexer .from_indices_shape (slc , self .ref_or_view .shape )
426+
427+ shape = self .ref_or_view .shape
428+
429+ # Expand ellipsis and process indices
430+ indices = _expand_ellipsis (list (slc ), len (shape ))
431+ none_positions , filtered_indices = _separate_none_indices (indices )
432+ converted_indices , flip_axes = _convert_negative_slices (filtered_indices , shape )
433+
434+ # Create indexer tuple
435+ if converted_indices :
436+ filtered_tuple = tuple (converted_indices )
437+ elif shape == ():
438+ filtered_tuple = ()
439+ else :
440+ filtered_tuple = (slice (None ),) * len (shape )
441+
442+ # Build the result
282443 if isinstance (self .ref_or_view , TransformedRef ):
283- view = self .ref_or_view
284- return TransformedRef (view .ref , (* view .transforms , indexer ))
285- return TransformedRef (self .ref_or_view , (indexer ,))
444+ base_ref = self .ref_or_view .ref
445+ current_transforms = self .ref_or_view .transforms
446+ else :
447+ base_ref = self .ref_or_view
448+ current_transforms = ()
449+
450+ transforms = list (current_transforms )
451+ if filtered_tuple :
452+ indexer = indexing .NDIndexer .from_indices_shape (filtered_tuple , shape )
453+ transforms .append (indexer )
454+ if flip_axes :
455+ transforms .append (RefFlip (tuple (flip_axes )))
456+ if none_positions :
457+ transforms .append (RefNewAxis (tuple (none_positions )))
458+ return TransformedRef (base_ref , tuple (transforms ))
286459
287460
288461@dataclasses .dataclass (frozen = True )
0 commit comments