Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 32 additions & 0 deletions pytensor/xtensor/rewriting/shape.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,11 @@
import numpy as np

from pytensor.graph import node_rewriter
from pytensor.raise_op import Assert
from pytensor.tensor import (
broadcast_to,
expand_dims,
gt,
join,
moveaxis,
specify_shape,
Expand All @@ -10,6 +15,7 @@
from pytensor.xtensor.rewriting.basic import register_lower_xtensor
from pytensor.xtensor.shape import (
Concat,
ExpandDims,
Squeeze,
Stack,
Transpose,
Expand Down Expand Up @@ -132,3 +138,29 @@ def local_squeeze_reshape(fgraph, node):

new_out = xtensor_from_tensor(x_tensor_squeezed, dims=node.outputs[0].type.dims)
return [new_out]


@register_lower_xtensor
@node_rewriter([ExpandDims])
def local_expand_dims_reshape(fgraph, node):
    """Lower an ExpandDims node to plain tensor operations.

    The new dimension is inserted as the leading axis with ``expand_dims``.
    What happens next depends on the ``size`` stored on the Op:

    * static size equal to 1: nothing else is needed;
    * static size greater than 1: the result is broadcast along the new axis.
      When some input dimension is not statically known, the broadcast uses
      the runtime shape of the input and the static output shape is
      re-attached with ``specify_shape``;
    * symbolic size: an ``Assert`` checks at runtime that the size is
      positive and ``specify_shape`` attaches the (partially unknown) static
      output shape. No broadcast is inserted in this branch.

    Parameters
    ----------
    fgraph
        The function graph being rewritten (unused here).
    node
        The ``ExpandDims`` Apply node to lower.

    Returns
    -------
    list
        A single-element list with the replacement XTensorVariable.
    """
    [x] = node.inputs

    if node.op.dims is None:
        # ExpandDims.make_node treats dims=None as a no-op, so the lowered
        # graph is simply the input itself.
        return [x]

    out_type = node.outputs[0].type
    target_shape = out_type.shape

    x_tensor = tensor_from_xtensor(x)
    x_tensor_expanded = expand_dims(x_tensor, axis=0)

    size = getattr(node.op, "size", 1)
    if isinstance(size, int | np.integer):
        if size != 1:
            if None in target_shape:
                # Some input dims are only known at runtime: broadcast against
                # the symbolic shape of the input, then restore the static
                # shape information declared by the node's output type.
                runtime_shape = [
                    x_tensor.shape[i] for i in range(x_tensor.type.ndim)
                ]
                x_tensor_expanded = broadcast_to(
                    x_tensor_expanded, (size, *runtime_shape)
                )
                x_tensor_expanded = specify_shape(x_tensor_expanded, target_shape)
            else:
                x_tensor_expanded = broadcast_to(x_tensor_expanded, target_shape)
    else:
        # Symbolic size: validate at runtime that it is positive, then attach
        # the static output shape.
        # NOTE(review): this branch does not broadcast to `size`; the new axis
        # keeps length 1 in the lowered graph — confirm this is intended.
        x_tensor_expanded = Assert(msg="size must be positive")(
            x_tensor_expanded, gt(size, 0)
        )
        x_tensor_expanded = specify_shape(x_tensor_expanded, target_shape)

    new_out = xtensor_from_tensor(x_tensor_expanded, dims=out_type.dims)
    return [new_out]
74 changes: 74 additions & 0 deletions pytensor/xtensor/shape.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
from types import EllipsisType
from typing import Literal

import numpy as np

from pytensor.graph import Apply
from pytensor.scalar import discrete_dtypes, upcast
from pytensor.tensor import as_tensor, get_scalar_constant_value
Expand Down Expand Up @@ -380,3 +382,75 @@ def squeeze(x, dim=None):
return x # no-op if nothing to squeeze

return Squeeze(dims=dims)(x)


class ExpandDims(XOp):
    """Add a new leading dimension to an XTensorVariable.

    Parameters
    ----------
    dim : str or None
        Name of the dimension to add. With ``None`` the Op is a no-op whose
        output has the same type as its input.
    size : int or scalar variable, optional
        Length of the new dimension (default 1). A Python/NumPy integer is
        recorded in the static output shape; anything else is treated as
        symbolic and the new dimension gets an unknown (``None``) static
        length.
    """

    # NOTE(review): a symbolic `size` is stored on the Op itself and is not
    # one of the node's inputs — confirm graph cloning/replacement and Op
    # hashing behave as expected in that case.
    __props__ = ("dims", "size")

    def __init__(self, dim, size=1):
        # Kept under the plural name `dims` to match `__props__`.
        self.dims = dim
        self.size = size

    def make_node(self, x):
        """Build the Apply node that prepends ``self.dims`` to ``x``.

        Parameters
        ----------
        x : XTensorVariable or convertible
            Input; converted with ``as_xtensor``.

        Returns
        -------
        Apply
            Node with ``x`` as single input and one output whose dims are
            ``(self.dims, *x.type.dims)``, or the same type as ``x`` when
            ``self.dims`` is None.
        """
        x = as_xtensor(x)

        if self.dims is None:
            # No-op: same type as the input. A fresh output variable is
            # required; passing `x` itself as the output would make it both
            # an input and an output of this node (and Apply refuses
            # outputs that already belong to another node).
            return Apply(self, [x], [x.type()])

        # The new dim is always inserted at the front
        new_dims = (self.dims, *x.type.dims)

        # Static length is only known for integer sizes
        if isinstance(self.size, int | np.integer):
            new_shape = (self.size, *x.type.shape)
        else:
            new_shape = (None, *x.type.shape)  # symbolic size

        out = xtensor(
            dtype=x.type.dtype,
            shape=new_shape,
            dims=new_dims,
        )
        return Apply(self, [x], [out])


def expand_dims(x, dim: str | None, size=1):
    """Prepend a new named dimension to an XTensorVariable.

    Parameters
    ----------
    x : XTensorVariable
        Input; converted with ``as_xtensor``.
    dim : str or None
        Name of the dimension to add. ``None`` returns the converted input
        unchanged.
    size : int or symbolic, optional
        Length of the new dimension (default 1). Either a positive integer
        or an object whose ``ndim`` attribute is 0.

    Returns
    -------
    XTensorVariable
        The input with the new dimension added.

    Raises
    ------
    TypeError
        If ``dim`` is neither a string nor None, or ``size`` is neither an
        integer nor a 0-d object.
    ValueError
        If ``dim`` is already a dimension of ``x``, or an integer ``size``
        is not positive.
    """
    x = as_xtensor(x)

    # Nothing to add
    if dim is None:
        return x

    if not isinstance(dim, str):
        raise TypeError(f"`dim` must be a string or None, got: {type(dim)}")

    existing_dims = x.type.dims
    if dim in existing_dims:
        raise ValueError(f"Dimension {dim} already exists in {existing_dims}")

    size_is_static = isinstance(size, int | np.integer)
    if size_is_static and size <= 0:
        raise ValueError(f"size must be positive, got: {size}")
    # Non-integer sizes must look like a 0-d (scalar) variable
    if not size_is_static and getattr(size, "ndim", None) != 0:
        raise TypeError(f"size must be an int or scalar variable, got: {type(size)}")

    op = ExpandDims(dim=dim, size=size)
    return op(x)
209 changes: 208 additions & 1 deletion tests/xtensor/test_shape.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,14 @@
from itertools import chain, combinations

import numpy as np
import pytest
import xarray as xr
from xarray import DataArray
from xarray import concat as xr_concat

from pytensor.tensor import scalar
from pytensor.xtensor.shape import (
concat,
expand_dims,
squeeze,
stack,
transpose,
Expand Down Expand Up @@ -301,6 +303,15 @@ def test_squeeze_explicit_dims():
fn3d = xr_function([x3], y3d)
xr_assert_allclose(fn3d(x3_test), x3_test)

# Reversibility with expand_dims
x6 = xtensor("x6", dims=("a", "b", "c"), shape=(2, 1, 3))
y6 = squeeze(x6, "b")
# First expand_dims adds at front, then transpose puts it in the right place
z6 = transpose(expand_dims(y6, "b"), "a", "b", "c")
fn6 = xr_function([x6], z6)
x6_test = xr_arange_like(x6)
xr_assert_allclose(fn6(x6_test), x6_test)


def test_squeeze_implicit_dims():
"""Test squeeze with implicit dim=None (all size-1 dimensions)."""
Expand Down Expand Up @@ -369,3 +380,199 @@ def test_squeeze_errors():
fn2 = xr_function([x2], y2)
with pytest.raises(Exception):
fn2(x2_test)


def test_expand_dims_explicit():
    """Check expand_dims with named dims and static sizes against xarray."""

    # Prepend "country" to inputs of rank 1, 2 and 3
    ranked_inputs = (
        (("city",), (3,)),
        (("city", "year"), (2, 2)),
        (("city", "year", "month"), (2, 2, 2)),
    )
    for in_dims, in_shape in ranked_inputs:
        x = xtensor("x", dims=in_dims, shape=in_shape)
        fn = xr_function([x], expand_dims(x, "country"))
        x_val = xr_arange_like(x)
        xr_assert_allclose(fn(x_val), x_val.expand_dims("country"))

    # The new dim always lands in front, with static length 1
    x = xtensor("x", dims=("a", "b"), shape=(2, 3))
    for new_dim in ("x", "y", "z"):
        out = expand_dims(x, new_dim)
        assert out.type.dims == (new_dim, "a", "b")
        assert out.type.shape == (1, 2, 3)

    # Passing size=1 explicitly matches the default
    fn_explicit = xr_function([x], expand_dims(x, "batch", size=1))
    fn_default = xr_function([x], expand_dims(x, "batch"))
    x_val = xr_arange_like(x)
    xr_assert_allclose(fn_explicit(x_val), fn_default(x_val))

    # 0-d input
    x = xtensor("x", dims=(), shape=())
    out = expand_dims(x, "batch")
    assert out.type.dims == ("batch",)
    assert out.type.shape == (1,)
    fn = xr_function([x], out)
    x_val = xr_arange_like(x)
    xr_assert_allclose(fn(x_val), x_val.expand_dims("batch"))

    # Static size > 1 broadcasts along the new dim
    x = xtensor("x", dims=("a", "b"), shape=(2, 3))
    fn = xr_function([x], expand_dims(x, "batch", size=4))
    x_val = xr_arange_like(x)
    expected = xr.DataArray(
        np.broadcast_to(x_val.data, (4, 2, 3)),
        dims=("batch", "a", "b"),
        coords={"a": x_val.coords["a"], "b": x_val.coords["b"]},
    )
    xr_assert_allclose(fn(x_val), expected)

    # Move the new dim between the existing ones: ("a", "new", "b")
    x = xtensor("x", dims=("a", "b"), shape=(2, 3))
    moved = transpose(expand_dims(x, "new"), "a", "new", "b")
    fn = xr_function([x], moved)
    x_val = xr_arange_like(x)
    xr_assert_allclose(
        fn(x_val), x_val.expand_dims("new").transpose("a", "new", "b")
    )

    # Two successive expansions of a 0-d input
    x = xtensor("x", dims=(), shape=())
    fn = xr_function([x], expand_dims(expand_dims(x, "a"), "b"))
    x_val = xr_arange_like(x)
    xr_assert_allclose(fn(x_val), x_val.expand_dims("a").expand_dims("b"))


def test_expand_dims_implicit():
    """Check expand_dims with symbolic sizes, the default size and dim=None."""

    def check_symbolic(size_name, runtime_size, graph_op, xarray_op):
        # Build expand_dims(x, "batch", size=<symbolic>), apply `graph_op` to
        # it, and compare with `xarray_op` applied to the xarray expansion.
        size_var = scalar(size_name, dtype="int64")
        x = xtensor("x", dims=("a", "b"), shape=(2, 3))
        out = graph_op(expand_dims(x, "batch", size=size_var))
        fn = xr_function([x, size_var], out, on_unused_input="ignore")
        x_val = xr_arange_like(x)
        expected = xarray_op(x_val.expand_dims("batch"))
        xr_assert_allclose(fn(x_val, runtime_size), expected)

    def identity(v):
        return v

    def doubled(v):
        return v + v

    # Symbolic size fed 1 at runtime: same as the default
    check_symbolic("size_sym_1", 1, identity, identity)

    # Symbolic size fed 4 at runtime (expand only adds a length-1 dim)
    check_symbolic("size_sym_4", 4, identity, identity)

    # Symbolic size followed by an elementwise op on the result
    check_symbolic("size_sym_4", 4, doubled, doubled)

    # Same graph again, exercising shape validation
    check_symbolic("size_sym", 4, doubled, doubled)

    # Symbolic size followed by a reduction over the new dim
    check_symbolic(
        "size_sym", 4, lambda v: v.sum("batch"), lambda v: v.sum("batch")
    )

    # Symbolic size followed by a transpose
    check_symbolic(
        "size_sym",
        4,
        lambda v: transpose(v, "batch", "a", "b"),
        lambda v: v.transpose("batch", "a", "b"),
    )

    # Expanding and then squeezing gives the input back
    x = xtensor("x", dims=("a", "b"), shape=(2, 3))
    roundtrip = squeeze(expand_dims(x, "batch"), "batch")
    fn = xr_function([x], roundtrip)
    x_val = xr_arange_like(x)
    xr_assert_allclose(fn(x_val), x_val)

    # dim=None leaves the input untouched
    x = xtensor("x", dims=("a",), shape=(3,))
    fn = xr_function([x], expand_dims(x, None))
    x_val = xr_arange_like(x)
    xr_assert_allclose(fn(x_val), x_val)

    # Elementwise op after a symbolic size fed 1 at runtime
    check_symbolic("size_sym", 1, doubled, doubled)


def test_expand_dims_errors():
    """Check that expand_dims rejects invalid dims and sizes."""
    base = xtensor("x", dims=("city",), shape=(3,))

    # A dim that is already present cannot be added again
    with_country = expand_dims(base, "country")
    with pytest.raises(ValueError, match="already exists"):
        expand_dims(with_country, "city")

    # A static size of 0 is rejected
    with pytest.raises(ValueError, match="size must be.*positive"):
        expand_dims(base, "batch", size=0)

    # Wrong argument types: non-string dim, then non-scalar size
    wrong_type_calls = (
        ((123,), {}),
        (("new",), {"size": [1]}),
    )
    for call_args, call_kwargs in wrong_type_calls:
        with pytest.raises(TypeError):
            expand_dims(base, *call_args, **call_kwargs)

    # Adding the same new dim twice
    with_new = expand_dims(base, "new")
    with pytest.raises(ValueError):
        expand_dims(with_new, "new")

    # A symbolic size that turns out to be 0 at runtime
    runtime_size = scalar("size_sym", dtype="int64")
    expanded = expand_dims(base, "batch", size=runtime_size)
    fn = xr_function([base, runtime_size], expanded, on_unused_input="ignore")
    with pytest.raises(Exception):
        fn(xr_arange_like(base), 0)