34 changes: 33 additions & 1 deletion pytensor/xtensor/rewriting/shape.py
@@ -1,6 +1,7 @@
from pytensor.graph import node_rewriter
from pytensor.tensor import (
broadcast_to,
expand_dims,
join,
moveaxis,
specify_shape,
@@ -10,6 +11,7 @@
from pytensor.xtensor.rewriting.basic import register_lower_xtensor
from pytensor.xtensor.shape import (
Concat,
ExpandDims,
Squeeze,
Stack,
Transpose,
@@ -121,7 +123,7 @@ def lower_transpose(fgraph, node):

@register_lower_xtensor
@node_rewriter([Squeeze])
def local_squeeze_reshape(fgraph, node):
def lower_squeeze(fgraph, node):
"""Rewrite Squeeze to tensor.squeeze."""
[x] = node.inputs
x_tensor = tensor_from_xtensor(x)
@@ -132,3 +134,33 @@ def local_squeeze_reshape(fgraph, node):

new_out = xtensor_from_tensor(x_tensor_squeezed, dims=node.outputs[0].type.dims)
return [new_out]


@register_lower_xtensor
@node_rewriter([ExpandDims])
def lower_expand_dims(fgraph, node):
"""Rewrite ExpandDims using tensor operations."""
x, size = node.inputs
out = node.outputs[0]

# Convert inputs to tensors
x_tensor = tensor_from_xtensor(x)
size_tensor = tensor_from_xtensor(size)

    # The new dimension is always inserted at the front
    new_axis = 0

    # Pick the lowering path based on the static output shape
    if out.type.shape[0] == 1:
# Simple case: just expand with size 1
result_tensor = expand_dims(x_tensor, new_axis)
else:
# Otherwise broadcast to the requested size
result_tensor = broadcast_to(x_tensor, (size_tensor, *x_tensor.shape))

# Preserve static shape information
result_tensor = specify_shape(result_tensor, out.type.shape)

# Convert result back to xtensor
result = xtensor_from_tensor(result_tensor, dims=out.type.dims)
return [result]
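For reference, a minimal sketch of the two lowering paths on plain tensors (illustrative variable names; `expand_dims` and `broadcast_to` are the same functions imported at the top of this file):

```python
import pytensor.tensor as pt

x = pt.matrix("x")                       # the tensor unwrapped from the xtensor input
size = pt.scalar("size", dtype="int64")  # the symbolic size input

# Static size 1: a plain expand_dims at the front suffices.
y_static = pt.expand_dims(x, 0)

# Unknown or non-1 size: broadcast to (size, *x.shape) instead.
y_dynamic = pt.broadcast_to(x, (size, *x.shape))
```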
113 changes: 113 additions & 0 deletions pytensor/xtensor/shape.py
@@ -3,10 +3,14 @@
from types import EllipsisType
from typing import Literal

import numpy as np

from pytensor.graph import Apply
from pytensor.scalar import discrete_dtypes, upcast
from pytensor.tensor import as_tensor, get_scalar_constant_value
from pytensor.tensor.exceptions import NotScalarConstantError
from pytensor.tensor.type import integer_dtypes
from pytensor.tensor.variable import TensorVariable
from pytensor.xtensor.basic import XOp
from pytensor.xtensor.type import as_xtensor, xtensor

@@ -380,3 +384,112 @@ def squeeze(x, dim=None):
return x # no-op if nothing to squeeze

return Squeeze(dims=dims)(x)


class ExpandDims(XOp):
"""Add a new dimension to an XTensorVariable."""

__props__ = ("dim",)

def __init__(self, dim):
if not isinstance(dim, str):
            raise TypeError(f"`dim` must be a string, got: {type(dim)}")

self.dim = dim

def make_node(self, x, size):
x = as_xtensor(x)

if self.dim in x.type.dims:
raise ValueError(f"Dimension {self.dim} already exists in {x.type.dims}")

size = as_xtensor(size, dims=())
if not (size.dtype in integer_dtypes and size.ndim == 0):
raise ValueError(f"size should be an integer scalar, got {size.type}")
try:
static_size = int(get_scalar_constant_value(size))
except NotScalarConstantError:
static_size = None
# If size is a constant, validate it
if static_size is not None and static_size < 0:
raise ValueError(f"size must be 0 or positive, got: {static_size}")
new_shape = (static_size, *x.type.shape)

# Insert new dim at front
new_dims = (self.dim, *x.type.dims)

out = xtensor(
dtype=x.type.dtype,
shape=new_shape,
dims=new_dims,
)
return Apply(self, [x, size], [out])


def expand_dims(x, dim=None, create_index_for_new_dim=True, axis=None, **dim_kwargs):
"""Add one or more new dimensions to an XTensorVariable."""
x = as_xtensor(x)

# Extract size from dim_kwargs if present
size = dim_kwargs.pop("size", 1) if dim_kwargs else 1

    # xarray compatibility: error if both a sequence (list/tuple) of dims and an explicit size are given
    if isinstance(dim, list | tuple) and size != 1:
        raise ValueError("cannot specify both keyword and positional arguments")

if dim is None:
dim = dim_kwargs
elif dim_kwargs:
raise ValueError("Cannot specify both `dim` and `**dim_kwargs`")

# Normalize to a dimension-size mapping
if isinstance(dim, str):
dims_dict = {dim: size}
elif isinstance(dim, Sequence) and not isinstance(dim, dict):
dims_dict = {d: 1 for d in dim}
elif isinstance(dim, dict):
dims_dict = {}
for name, val in dim.items():
if isinstance(val, Sequence | np.ndarray) and not isinstance(val, str):
dims_dict[name] = len(val)
elif isinstance(val, int):
dims_dict[name] = val
else:
dims_dict[name] = val # symbolic/int scalar allowed
else:
raise TypeError(f"Invalid type for `dim`: {type(dim)}")

    # Convert to canonical form: list of (dim_name, size)
    canonical_dims: list[tuple[str, int | np.integer | TensorVariable]] = list(
        dims_dict.items()
    )

# Store original dimensions for later use with axis
original_dims = list(x.type.dims)

# Insert each new dim at the front (reverse order preserves user intent)
for name, size in reversed(canonical_dims):
x = ExpandDims(dim=name)(x, size)

# If axis is specified, transpose to put new dimensions in the right place
if axis is not None:
new_dim_names = [name for name, _ in canonical_dims]
# Wrap non-sequence axis in a list
if not isinstance(axis, Sequence):
axis = [axis]

# xarray requires len(axis) == len(new_dim_names)
if len(axis) != len(new_dim_names):
raise ValueError("lengths of dim and axis should be identical.")

# Insert each new dim at the specified axis position
# Start with original dims, then insert each new dim at its axis
target_dims = list(original_dims)
# axis values are relative to the result after each insertion
for insert_dim, insert_axis in sorted(
zip(new_dim_names, axis), key=lambda x: x[1]
):
target_dims.insert(insert_axis, insert_dim)
x = Transpose(dims=tuple(target_dims))(x)

return x
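A short usage sketch of the new helper (imports assumed from the modules shown in this diff), mirroring the cases exercised by the tests further down:

```python
from pytensor.xtensor.shape import expand_dims
from pytensor.xtensor.type import xtensor

x = xtensor("x", dims=("city", "year"), shape=(2, 2))

y1 = expand_dims(x, "country")                          # size-1 dim at the front
y2 = expand_dims(x, {"country": 4})                     # explicit size via dict
y3 = expand_dims(x, ["country", "state"], axis=[1, 2])  # placed with axis

assert y1.type.dims == ("country", "city", "year")
assert y3.type.dims == ("city", "country", "state", "year")
```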
41 changes: 41 additions & 0 deletions pytensor/xtensor/type.py
@@ -481,6 +481,47 @@ def squeeze(
raise NotImplementedError("Squeeze with axis not Implemented")
return px.shape.squeeze(self, dim)

def expand_dims(
self,
dim: str | Sequence[str] | dict[str, int | Sequence] | None = None,
create_index_for_new_dim: bool = True,
        axis: int | Sequence[int] | None = None,
**dim_kwargs,
):
"""Add one or more new dimensions to the tensor.

Parameters
----------
dim : str | Sequence[str] | dict[str, int | Sequence] | None
If str or sequence of str, new dimensions with size 1.
If dict, keys are dimension names and values are either:
- int: the new size
- sequence: coordinates (length determines size)
create_index_for_new_dim : bool, default: True
Currently ignored. Reserved for future coordinate support.
In xarray, when True (default), creates a coordinate index for the new dimension
with values from 0 to size-1. When False, no coordinate index is created.
        axis : int | Sequence[int] | None, default: None
            Position(s) at which to insert the new dimension(s), following xarray
            semantics. By default (None), new dimensions are inserted at the
            beginning (axis=0). Negative values count from the end.
            Symbolic axis is not supported yet.
**dim_kwargs : int | Sequence
Alternative to `dim` dict. Only used if `dim` is None.

Returns
-------
XTensorVariable
            A tensor with the additional dimension(s) inserted, at the front by
            default or at the positions given by `axis`.
"""
return px.shape.expand_dims(
self,
dim,
create_index_for_new_dim=create_index_for_new_dim,
axis=axis,
**dim_kwargs,
)

# ndarray methods
# https://docs.xarray.dev/en/latest/api.html#id7
def clip(self, min, max):
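For illustration, a minimal sketch of the method-level call with keyword sizes (equivalent to passing a dict), assuming the same `xtensor` constructor used in the tests:

```python
from pytensor.xtensor.type import xtensor

x = xtensor("x", dims=("city", "year"), shape=(2, 2))

# Keyword form; equivalent to x.expand_dims({"country": 2, "state": 3})
y = x.expand_dims(country=2, state=3)
assert y.type.dims == ("country", "state", "city", "year")
```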
106 changes: 105 additions & 1 deletion tests/xtensor/test_shape.py
@@ -8,10 +8,10 @@
from itertools import chain, combinations

import numpy as np
import pytest
from xarray import DataArray
from xarray import concat as xr_concat

from pytensor.tensor import scalar
from pytensor.xtensor.shape import (
concat,
squeeze,
@@ -369,3 +369,107 @@ def test_squeeze_errors():
fn2 = xr_function([x2], y2)
with pytest.raises(Exception):
fn2(x2_test)


def test_expand_dims():
"""Test expand_dims."""
x = xtensor("x", dims=("city", "year"), shape=(2, 2))
x_test = xr_arange_like(x)

# Implicit size=1
y = x.expand_dims("country")
fn = xr_function([x], y)
xr_assert_allclose(fn(x_test), x_test.expand_dims("country"))

    # Explicit size=1
    y = x.expand_dims("country", size=1)
    fn = xr_function([x], y)
    xr_assert_allclose(fn(x_test), x_test.expand_dims("country"))

# Explicit size > 1
y = x.expand_dims("country", size=4)
fn = xr_function([x], y)
xr_assert_allclose(fn(x_test), x_test.expand_dims({"country": 4}))

# Test with multiple dimensions
y = x.expand_dims(["country", "state"])
fn = xr_function([x], y)
xr_assert_allclose(fn(x_test), x_test.expand_dims(["country", "state"]))

# Test with a dict of sizes
y = x.expand_dims({"country": 2, "state": 3})
fn = xr_function([x], y)
xr_assert_allclose(fn(x_test), x_test.expand_dims({"country": 2, "state": 3}))

# Test with kwargs (equivalent to dict)
y = x.expand_dims(country=2, state=3)
fn = xr_function([x], y)
xr_assert_allclose(fn(x_test), x_test.expand_dims(country=2, state=3))

# Symbolic size=1
size_sym_1 = scalar("size_sym_1", dtype="int64")
y = x.expand_dims("country", size=size_sym_1)
fn = xr_function([x, size_sym_1], y)
xr_assert_allclose(fn(x_test, 1), x_test.expand_dims("country"))

# Test behavior with symbolic size > 1
# NOTE: This test documents our current behavior where expand_dims broadcasts to the requested size.
# This differs from xarray's behavior where expand_dims always adds a size-1 dimension.
size_sym_4 = scalar("size_sym_4", dtype="int64")
y = x.expand_dims("country", size=size_sym_4)
fn = xr_function([x, size_sym_4], y)
res = fn(x_test, 4)
# Our current behavior: broadcasts to size 4
expected = x_test.expand_dims({"country": 4})
xr_assert_allclose(res, expected)
# xarray's behavior would be:
# expected = x_test.expand_dims("country") # always size 1
# xr_assert_allclose(res, expected)

# Test with symbolic sizes in dict
size_sym_1 = scalar("size_sym_1", dtype="int64")
size_sym_2 = scalar("size_sym_2", dtype="int64")
y = x.expand_dims({"country": size_sym_1, "state": size_sym_2})
fn = xr_function([x, size_sym_1, size_sym_2], y)
xr_assert_allclose(fn(x_test, 2, 3), x_test.expand_dims({"country": 2, "state": 3}))

# Test with symbolic sizes in kwargs
y = x.expand_dims(country=size_sym_1, state=size_sym_2)
fn = xr_function([x, size_sym_1, size_sym_2], y)
xr_assert_allclose(fn(x_test, 2, 3), x_test.expand_dims({"country": 2, "state": 3}))

# Test with axis parameter
y = x.expand_dims("country", axis=1)
fn = xr_function([x], y)
xr_assert_allclose(fn(x_test), x_test.expand_dims("country", axis=1))

# Add two new dims at axis=[1, 2]
y = x.expand_dims(["country", "state"], axis=[1, 2])
fn = xr_function([x], y)
xr_assert_allclose(
fn(x_test), x_test.expand_dims(["country", "state"], axis=[1, 2])
)


def test_expand_dims_errors():
"""Test error handling in expand_dims."""

# Expanding existing dim
x = xtensor("x", dims=("city",), shape=(3,))
y = x.expand_dims("country")
with pytest.raises(ValueError, match="already exists"):
y.expand_dims("city")

# Invalid dim type
with pytest.raises(TypeError, match="Invalid type for `dim`"):
x.expand_dims(123)

# Duplicate dimension creation
y = x.expand_dims("new")
with pytest.raises(ValueError, match="already exists"):
y.expand_dims("new")

# Test for error when both positional and size are given
with pytest.raises(
ValueError, match="cannot specify both keyword and positional arguments"
):
x.expand_dims(["country", "state"], size=3)