Skip to content

Commit 30354cd

Browse files
committed
Adding axis parameter
1 parent 38d87c6 commit 30354cd

File tree

3 files changed

+51
-3
lines changed

3 files changed

+51
-3
lines changed

pytensor/xtensor/shape.py

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -441,7 +441,7 @@ def make_node(self, x, size):
441441
return Apply(self, [x, size], [out])
442442

443443

444-
def expand_dims(x, dim=None, create_index_for_new_dim=True, **dim_kwargs):
444+
def expand_dims(x, dim=None, create_index_for_new_dim=True, axis=None, **dim_kwargs):
445445
"""Add one or more new dimensions to an XTensorVariable."""
446446
x = as_xtensor(x)
447447

@@ -479,8 +479,32 @@ def expand_dims(x, dim=None, create_index_for_new_dim=True, **dim_kwargs):
479479
for name, size in dims_dict.items():
480480
canonical_dims.append((name, size))
481481

482+
# Store original dimensions for later use with axis
483+
original_dims = list(x.type.dims)
484+
482485
# Insert each new dim at the front (reverse order preserves user intent)
483486
for name, size in reversed(canonical_dims):
484487
x = ExpandDims(dim=name)(x, size)
485488

489+
# If axis is specified, transpose to put new dimensions in the right place
490+
if axis is not None:
491+
new_dim_names = [name for name, _ in canonical_dims]
492+
# Wrap non-sequence axis in a list
493+
if not isinstance(axis, Sequence):
494+
axis = [axis]
495+
496+
# xarray requires len(axis) == len(new_dim_names)
497+
if len(axis) != len(new_dim_names):
498+
raise ValueError("lengths of dim and axis should be identical.")
499+
500+
# Insert each new dim at the specified axis position
501+
# Start with original dims, then insert each new dim at its axis
502+
target_dims = list(original_dims)
503+
# axis values are relative to the result after each insertion
504+
for insert_dim, insert_axis in sorted(
505+
zip(new_dim_names, axis), key=lambda x: x[1]
506+
):
507+
target_dims.insert(insert_axis, insert_dim)
508+
x = transpose(x, *target_dims)
509+
486510
return x

pytensor/xtensor/type.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -485,6 +485,7 @@ def expand_dims(
485485
self,
486486
dim: str | Sequence[str] | dict[str, int | Sequence] | None = None,
487487
create_index_for_new_dim: bool = True,
488+
axis: int | None = None,
488489
**dim_kwargs,
489490
):
490491
"""Add one or more new dimensions to the tensor.
@@ -497,7 +498,14 @@ def expand_dims(
497498
- int: the new size
498499
- sequence: coordinates (length determines size)
499500
create_index_for_new_dim : bool, default: True
500-
(Ignored for now) Matches xarray API, reserved for future use.
501+
Currently ignored. Reserved for future coordinate support.
502+
In xarray, when True (default), creates a coordinate index for the new dimension
503+
with values from 0 to size-1. When False, no coordinate index is created.
504+
axis : int | None, default: None
505+
Not implemented yet. In xarray, specifies where to insert the new dimension(s).
506+
By default (None), new dimensions are inserted at the beginning (axis=0).
507+
Symbolic axis is not supported yet.
508+
Negative values count from the end.
501509
**dim_kwargs : int | Sequence
502510
Alternative to `dim` dict. Only used if `dim` is None.
503511
@@ -507,7 +515,11 @@ def expand_dims(
507515
A tensor with additional dimensions inserted at the front.
508516
"""
509517
return px.shape.expand_dims(
510-
self, dim, create_index_for_new_dim=create_index_for_new_dim, **dim_kwargs
518+
self,
519+
dim,
520+
create_index_for_new_dim=create_index_for_new_dim,
521+
axis=axis,
522+
**dim_kwargs,
511523
)
512524

513525
# ndarray methods

tests/xtensor/test_shape.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -437,6 +437,18 @@ def test_expand_dims():
437437
fn = xr_function([x, size_sym_1, size_sym_2], y)
438438
xr_assert_allclose(fn(x_test, 2, 3), x_test.expand_dims({"country": 2, "state": 3}))
439439

440+
# Test with axis parameter
441+
y = x.expand_dims("country", axis=1)
442+
fn = xr_function([x], y)
443+
xr_assert_allclose(fn(x_test), x_test.expand_dims("country", axis=1))
444+
445+
# Add two new dims at axis=[1, 2]
446+
y = x.expand_dims(["country", "state"], axis=[1, 2])
447+
fn = xr_function([x], y)
448+
xr_assert_allclose(
449+
fn(x_test), x_test.expand_dims(["country", "state"], axis=[1, 2])
450+
)
451+
440452

441453
def test_expand_dims_errors():
442454
"""Test error handling in expand_dims."""

0 commit comments

Comments
 (0)