@@ -441,7 +441,7 @@ def make_node(self, x, size):
441441 return Apply (self , [x , size ], [out ])
442442
443443
444- def expand_dims (x , dim = None , create_index_for_new_dim = True , ** dim_kwargs ):
444+ def expand_dims (x , dim = None , create_index_for_new_dim = True , axis = None , ** dim_kwargs ):
445445 """Add one or more new dimensions to an XTensorVariable."""
446446 x = as_xtensor (x )
447447
@@ -479,8 +479,32 @@ def expand_dims(x, dim=None, create_index_for_new_dim=True, **dim_kwargs):
479479 for name , size in dims_dict .items ():
480480 canonical_dims .append ((name , size ))
481481
482+ # Store original dimensions for later use with axis
483+ original_dims = list (x .type .dims )
484+
482485 # Insert each new dim at the front (reverse order preserves user intent)
483486 for name , size in reversed (canonical_dims ):
484487 x = ExpandDims (dim = name )(x , size )
485488
489+ # If axis is specified, transpose to put new dimensions in the right place
490+ if axis is not None :
491+ new_dim_names = [name for name , _ in canonical_dims ]
492+ # Wrap non-sequence axis in a list
493+ if not isinstance (axis , Sequence ):
494+ axis = [axis ]
495+
496+ # xarray requires len(axis) == len(new_dim_names)
497+ if len (axis ) != len (new_dim_names ):
498+ raise ValueError ("lengths of dim and axis should be identical." )
499+
500+ # Insert each new dim at the specified axis position
501+ # Start with original dims, then insert each new dim at its axis
502+ target_dims = list (original_dims )
503+ # axis values are relative to the result after each insertion
504+ for insert_dim , insert_axis in sorted (
505+ zip (new_dim_names , axis ), key = lambda x : x [1 ]
506+ ):
507+ target_dims .insert (insert_axis , insert_dim )
508+ x = transpose (x , * target_dims )
509+
486510 return x
0 commit comments