Skip to content

Commit 7c25746

Browse files
committed
Creating new branch for expand_dims
1 parent 1ff26c2 commit 7c25746

File tree

3 files changed

+273
-1
lines changed

3 files changed

+273
-1
lines changed

pytensor/xtensor/rewriting/shape.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,11 @@
1+
import numpy as np
2+
13
from pytensor.graph import node_rewriter
4+
from pytensor.raise_op import Assert
25
from pytensor.tensor import (
36
broadcast_to,
7+
expand_dims,
8+
gt,
49
join,
510
moveaxis,
611
specify_shape,
@@ -10,6 +15,7 @@
1015
from pytensor.xtensor.rewriting.basic import register_xcanonicalize
1116
from pytensor.xtensor.shape import (
1217
Concat,
18+
ExpandDims,
1319
Squeeze,
1420
Stack,
1521
Transpose,
@@ -132,3 +138,37 @@ def local_squeeze_reshape(fgraph, node):
132138

133139
new_out = xtensor_from_tensor(x_tensor_squeezed, dims=node.outputs[0].type.dims)
134140
return [new_out]
141+
142+
143+
@register_xcanonicalize
144+
@node_rewriter([ExpandDims])
145+
def local_expand_dims_reshape(fgraph, node):
146+
"""Rewrite ExpandDims to tensor.expand_dims and optionally broadcast_to or specify_shape."""
147+
if not isinstance(node.op, ExpandDims):
148+
return False
149+
150+
x = node.inputs[0]
151+
dim = node.op.dims
152+
size = getattr(node.op, "size", 1)
153+
154+
if dim is None:
155+
return [x]
156+
157+
x_tensor = tensor_from_xtensor(x)
158+
x_tensor_expanded = expand_dims(x_tensor, axis=0)
159+
160+
target_shape = node.outputs[0].type.shape
161+
162+
if isinstance(size, int | np.integer):
163+
if size != 1 and None not in target_shape:
164+
x_tensor_expanded = broadcast_to(x_tensor_expanded, target_shape)
165+
else:
166+
# Symbolic size: enforce shape so broadcast happens downstream correctly
167+
# Also validate that size is positive
168+
x_tensor_expanded = Assert(msg="size must be positive")(
169+
x_tensor_expanded, gt(size, 0)
170+
)
171+
x_tensor_expanded = specify_shape(x_tensor_expanded, target_shape)
172+
173+
new_out = xtensor_from_tensor(x_tensor_expanded, dims=node.outputs[0].type.dims)
174+
return [new_out]

pytensor/xtensor/shape.py

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22
from collections.abc import Sequence
33
from typing import Literal
44

5+
import numpy as np
6+
57
from pytensor import Variable
68
from pytensor.graph import Apply
79
from pytensor.scalar import discrete_dtypes, upcast
@@ -380,3 +382,81 @@ def squeeze(x, dim=None):
380382
return x # no-op if nothing to squeeze
381383

382384
return Squeeze(dims=dims)(x)
385+
386+
387+
class ExpandDims(XOp):
388+
"""Add a new dimension to an XTensorVariable."""
389+
390+
__props__ = ("dims", "size")
391+
392+
def __init__(self, dim, size=1):
393+
self.dims = dim
394+
self.size = size
395+
396+
def make_node(self, x):
397+
x = as_xtensor(x)
398+
399+
if self.dims is None:
400+
# No-op: return same variable
401+
return Apply(self, [x], [x])
402+
403+
# Insert new dim at front
404+
new_dims = (self.dims, *x.type.dims)
405+
406+
# Determine shape
407+
if isinstance(self.size, int | np.integer):
408+
new_shape = (self.size, *x.type.shape)
409+
else:
410+
new_shape = (None, *x.type.shape) # symbolic size
411+
412+
out = xtensor(
413+
dtype=x.type.dtype,
414+
shape=new_shape,
415+
dims=new_dims,
416+
)
417+
return Apply(self, [x], [out])
418+
419+
def infer_shape(self, fgraph, node, input_shapes):
420+
(input_shape,) = input_shapes
421+
if self.dims is None:
422+
return [input_shape]
423+
return [(self.size, *list(input_shape))]
424+
425+
426+
def expand_dims(x, dim: str | None, size=1):
427+
"""Add a new dimension to an XTensorVariable.
428+
429+
Parameters
430+
----------
431+
x : XTensorVariable
432+
Input tensor
433+
dim : str or None
434+
Name of new dimension. If None, returns x unchanged.
435+
size : int or symbolic, optional
436+
Size of the new dimension (default 1)
437+
438+
Returns
439+
-------
440+
XTensorVariable
441+
Tensor with the new dimension inserted
442+
"""
443+
x = as_xtensor(x)
444+
445+
if dim is None:
446+
return x # No-op
447+
448+
if not isinstance(dim, str):
449+
raise TypeError(f"`dim` must be a string or None, got: {type(dim)}")
450+
451+
if dim in x.type.dims:
452+
raise ValueError(f"Dimension {dim} already exists in {x.type.dims}")
453+
454+
if isinstance(size, int | np.integer):
455+
if size <= 0:
456+
raise ValueError(f"size must be positive, got: {size}")
457+
elif not (
458+
hasattr(size, "ndim") and getattr(size, "ndim", None) == 0 # symbolic scalar
459+
):
460+
raise TypeError(f"size must be an int or scalar variable, got: {type(size)}")
461+
462+
return ExpandDims(dim=dim, size=size)(x)

tests/xtensor/test_shape.py

Lines changed: 153 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,12 +8,14 @@
88
from itertools import chain, combinations
99

1010
import numpy as np
11-
import pytest
11+
import xarray as xr
1212
from xarray import DataArray
1313
from xarray import concat as xr_concat
1414

15+
from pytensor.tensor import scalar
1516
from pytensor.xtensor.shape import (
1617
concat,
18+
expand_dims,
1719
squeeze,
1820
stack,
1921
transpose,
@@ -369,3 +371,153 @@ def test_squeeze_errors():
369371
fn2 = xr_function([x2], y2)
370372
with pytest.raises(Exception):
371373
fn2(x2_test)
374+
375+
376+
def test_expand_dims_explicit():
377+
"""Test expand_dims with explicitly named dimensions and sizes."""
378+
379+
# 1D case
380+
x = xtensor("x", dims=("city",), shape=(3,))
381+
y = expand_dims(x, "country")
382+
fn = xr_function([x], y)
383+
x_xr = xr_arange_like(x)
384+
xr_assert_allclose(fn(x_xr), x_xr.expand_dims("country"))
385+
386+
# 2D case
387+
x = xtensor("x", dims=("city", "year"), shape=(2, 2))
388+
y = expand_dims(x, "country")
389+
fn = xr_function([x], y)
390+
xr_assert_allclose(fn(xr_arange_like(x)), xr_arange_like(x).expand_dims("country"))
391+
392+
# 3D case
393+
x = xtensor("x", dims=("city", "year", "month"), shape=(2, 2, 2))
394+
y = expand_dims(x, "country")
395+
fn = xr_function([x], y)
396+
xr_assert_allclose(fn(xr_arange_like(x)), xr_arange_like(x).expand_dims("country"))
397+
398+
# Prepending various dims
399+
x = xtensor("x", dims=("a", "b"), shape=(2, 3))
400+
for new_dim in ("x", "y", "z"):
401+
y = expand_dims(x, new_dim)
402+
assert y.type.dims == (new_dim, "a", "b")
403+
assert y.type.shape == (1, 2, 3)
404+
405+
# Explicit size=1 behaves like default
406+
y1 = expand_dims(x, "batch", size=1)
407+
y2 = expand_dims(x, "batch")
408+
fn1 = xr_function([x], y1)
409+
fn2 = xr_function([x], y2)
410+
x_test = xr_arange_like(x)
411+
xr_assert_allclose(fn1(x_test), fn2(x_test))
412+
413+
# Scalar expansion
414+
x = xtensor("x", dims=(), shape=())
415+
y = expand_dims(x, "batch")
416+
assert y.type.dims == ("batch",)
417+
assert y.type.shape == (1,)
418+
fn = xr_function([x], y)
419+
xr_assert_allclose(fn(xr_arange_like(x)), xr_arange_like(x).expand_dims("batch"))
420+
421+
# Static size > 1: broadcast
422+
x = xtensor("x", dims=("a", "b"), shape=(2, 3))
423+
y = expand_dims(x, "batch", size=4)
424+
fn = xr_function([x], y)
425+
expected = xr.DataArray(
426+
np.broadcast_to(xr_arange_like(x).data, (4, 2, 3)),
427+
dims=("batch", "a", "b"),
428+
coords={"a": xr_arange_like(x).coords["a"], "b": xr_arange_like(x).coords["b"]},
429+
)
430+
xr_assert_allclose(fn(xr_arange_like(x)), expected)
431+
432+
# Insert new dim between existing dims
433+
x = xtensor("x", dims=("a", "b"), shape=(2, 3))
434+
y = expand_dims(x, "new")
435+
# Insert new dim between a and b: ("a", "new", "b")
436+
y = transpose(y, "a", "new", "b")
437+
fn = xr_function([x], y)
438+
x_test = xr_arange_like(x)
439+
expected = x_test.expand_dims("new").transpose("a", "new", "b")
440+
xr_assert_allclose(fn(x_test), expected)
441+
442+
# Expand with multiple dims
443+
x = xtensor("x", dims=(), shape=())
444+
y = expand_dims(expand_dims(x, "a"), "b")
445+
fn = xr_function([x], y)
446+
expected = xr_arange_like(x).expand_dims("a").expand_dims("b")
447+
xr_assert_allclose(fn(xr_arange_like(x)), expected)
448+
449+
450+
def test_expand_dims_implicit():
451+
"""Test expand_dims with default or symbolic sizes and dim=None."""
452+
453+
# Symbolic size=1: same as default
454+
size_sym_1 = scalar("size_sym_1", dtype="int64")
455+
x = xtensor("x", dims=("a", "b"), shape=(2, 3))
456+
y = expand_dims(x, "batch", size=size_sym_1)
457+
fn = xr_function([x, size_sym_1], y, on_unused_input="ignore")
458+
expected = xr_arange_like(x).expand_dims("batch")
459+
xr_assert_allclose(fn(xr_arange_like(x), 1), expected)
460+
461+
# Symbolic size > 1 (but expand only adds dim=1)
462+
size_sym_4 = scalar("size_sym_4", dtype="int64")
463+
y = expand_dims(x, "batch", size=size_sym_4)
464+
fn = xr_function([x, size_sym_4], y, on_unused_input="ignore")
465+
xr_assert_allclose(fn(xr_arange_like(x), 4), expected)
466+
467+
# Reversibility: expand then squeeze
468+
x = xtensor("x", dims=("a", "b"), shape=(2, 3))
469+
y = expand_dims(x, "batch")
470+
z = squeeze(y, "batch")
471+
fn = xr_function([x], z)
472+
xr_assert_allclose(fn(xr_arange_like(x)), xr_arange_like(x))
473+
474+
# expand_dims with dim=None = no-op
475+
x = xtensor("x", dims=("a",), shape=(3,))
476+
y = expand_dims(x, None)
477+
fn = xr_function([x], y)
478+
xr_assert_allclose(fn(xr_arange_like(x)), xr_arange_like(x))
479+
480+
# broadcast after symbolic size
481+
size_sym = scalar("size_sym", dtype="int64")
482+
x = xtensor("x", dims=("a", "b"), shape=(2, 3))
483+
y = expand_dims(x, "batch", size=size_sym)
484+
z = y + y # triggers shape alignment
485+
fn = xr_function([x, size_sym], z, on_unused_input="ignore")
486+
x_test = xr_arange_like(x)
487+
out = fn(x_test, 1)
488+
expected = x_test.expand_dims("batch") + x_test.expand_dims("batch")
489+
xr_assert_allclose(out, expected)
490+
491+
492+
def test_expand_dims_errors():
493+
"""Test error handling in expand_dims."""
494+
495+
# Expanding existing dim
496+
x = xtensor("x", dims=("city",), shape=(3,))
497+
y = expand_dims(x, "country")
498+
with pytest.raises(ValueError, match="already exists"):
499+
expand_dims(y, "city")
500+
501+
# Size = 0 is invalid
502+
with pytest.raises(ValueError, match="size must be.*positive"):
503+
expand_dims(x, "batch", size=0)
504+
505+
# Invalid dim type
506+
with pytest.raises(TypeError):
507+
expand_dims(x, 123)
508+
509+
# Invalid size type
510+
with pytest.raises(TypeError):
511+
expand_dims(x, "new", size=[1])
512+
513+
# Duplicate dimension creation
514+
y = expand_dims(x, "new")
515+
with pytest.raises(ValueError):
516+
expand_dims(y, "new")
517+
518+
# Symbolic size with invalid runtime value
519+
size_sym = scalar("size_sym", dtype="int64")
520+
y = expand_dims(x, "batch", size=size_sym)
521+
fn = xr_function([x, size_sym], y, on_unused_input="ignore")
522+
with pytest.raises(Exception):
523+
fn(xr_arange_like(x), 0)

0 commit comments

Comments
 (0)