@@ -271,18 +271,110 @@ def pretty_print(self, context: core.JaxprPpContext) -> pp.Doc:
271271 return pp .text (f"{{{ self } }}" )
272272
273273
274+ @tree_util .register_pytree_node_class
275+ @dataclasses .dataclass (frozen = True )
276+ class RefNewAxis :
277+ """Transform that inserts new axes at specified positions."""
278+ positions : tuple [int , ...] # positions to insert new axes (in output)
279+
280+ def tree_flatten (self ):
281+ return (), (self .positions ,)
282+
283+ @classmethod
284+ def tree_unflatten (cls , metadata , arrays ):
285+ assert not arrays
286+ return cls (* metadata )
287+
288+ def transform_shape (
289+ self , shape : tuple [int | Array , ...] | None
290+ ) -> tuple [int | Array , ...] | None :
291+ if shape is None :
292+ return None
293+ result = list (shape )
294+ for pos in sorted (self .positions ):
295+ result .insert (pos , 1 )
296+ return tuple (result )
297+
298+ def transform_dtype (self , dtype ):
299+ return dtype
300+
301+ def transform_sharding (self , sharding ):
302+ if all (p is None for p in sharding .spec ):
303+ return sharding
304+ raise NotImplementedError
305+
306+ def pretty_print (self , context : core .JaxprPpContext ) -> pp .Doc :
307+ del context
308+ return pp .text (f"{{newaxis{ list (self .positions )} }}" )
309+
310+
274311@dataclasses .dataclass
275312class RefIndexer :
276313 ref_or_view : Any
277314
278315 def __getitem__ (self , slc ):
279316 if not isinstance (slc , tuple ):
280317 slc = (slc ,)
281- indexer = indexing .NDIndexer .from_indices_shape (slc , self .ref_or_view .shape )
318+
319+ # Handle None values (np.newaxis) in indices
320+ none_positions = []
321+ filtered_indices = []
322+ output_pos = 0
323+
324+ # First pass: expand ellipsis if present
325+ indices = list (slc )
326+ num_none = sum (idx is None for idx in indices )
327+ num_ellipsis = sum (idx is ... for idx in indices )
328+
329+ if num_ellipsis > 0 :
330+ # Expand ellipsis accounting for None values
331+ ip = indices .index (...)
332+ num_real_indices = len (indices ) - num_ellipsis - num_none
333+ num_slices_needed = len (self .ref_or_view .shape ) - num_real_indices
334+ indices [ip :ip + 1 ] = [slice (None )] * max (0 , num_slices_needed )
335+
336+ # Second pass: separate None from other indices and track positions
337+ for idx in indices :
338+ if idx is None :
339+ none_positions .append (output_pos )
340+ output_pos += 1
341+ else :
342+ filtered_indices .append (idx )
343+ # Slices and ints both consume one input dim
344+ # but slices produce one output dim, ints produce zero
345+ if isinstance (idx , slice ) or isinstance (idx , indexing .Slice ):
346+ output_pos += 1
347+ elif not isinstance (idx , (int , np .integer )) and hasattr (idx , 'shape' ) and idx .shape :
348+ # Array indexer produces its shape
349+ output_pos += len (idx .shape )
350+ # Scalar int indexers don't add to output dims
351+
352+ # Create indexer without None
353+ filtered_tuple = tuple (filtered_indices ) if filtered_indices else (slice (None ),) * len (self .ref_or_view .shape )
354+ if not filtered_indices and self .ref_or_view .shape == ():
355+ # Special case: scalar ref with only None indexing
356+ filtered_tuple = ()
357+
358+ # Build the result
282359 if isinstance (self .ref_or_view , TransformedRef ):
283360 view = self .ref_or_view
284- return TransformedRef (view .ref , (* view .transforms , indexer ))
285- return TransformedRef (self .ref_or_view , (indexer ,))
361+ if filtered_tuple :
362+ indexer = indexing .NDIndexer .from_indices_shape (filtered_tuple , view .shape )
363+ transforms = (* view .transforms , indexer )
364+ else :
365+ transforms = view .transforms
366+ if none_positions :
367+ transforms = (* transforms , RefNewAxis (tuple (none_positions )))
368+ return TransformedRef (view .ref , transforms )
369+ else :
370+ if filtered_tuple :
371+ indexer = indexing .NDIndexer .from_indices_shape (filtered_tuple , self .ref_or_view .shape )
372+ transforms = (indexer ,)
373+ else :
374+ transforms = ()
375+ if none_positions :
376+ transforms = (* transforms , RefNewAxis (tuple (none_positions )))
377+ return TransformedRef (self .ref_or_view , transforms )
286378
287379
288380@dataclasses .dataclass (frozen = True )
0 commit comments