Skip to content

Commit 19361d6

Browse files
committed
Adding expand_dims for xtensor
1 parent 7b8877b commit 19361d6

File tree

3 files changed

+314
-1
lines changed

3 files changed

+314
-1
lines changed

pytensor/xtensor/rewriting/shape.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,11 @@
1+
import numpy as np
2+
13
from pytensor.graph import node_rewriter
4+
from pytensor.raise_op import Assert
25
from pytensor.tensor import (
36
broadcast_to,
7+
expand_dims,
8+
gt,
49
join,
510
moveaxis,
611
specify_shape,
@@ -10,6 +15,7 @@
1015
from pytensor.xtensor.rewriting.basic import register_lower_xtensor
1116
from pytensor.xtensor.shape import (
1217
Concat,
18+
ExpandDims,
1319
Squeeze,
1420
Stack,
1521
Transpose,
@@ -132,3 +138,29 @@ def local_squeeze_reshape(fgraph, node):
132138

133139
new_out = xtensor_from_tensor(x_tensor_squeezed, dims=node.outputs[0].type.dims)
134140
return [new_out]
141+
142+
143+
@register_lower_xtensor
@node_rewriter([ExpandDims])
def local_expand_dims_reshape(fgraph, node):
    """Lower ``ExpandDims`` to plain tensor operations.

    The new dimension is always inserted as the leading axis with
    ``tensor.expand_dims``. A static ``size`` other than 1 is then materialised
    with ``broadcast_to``; a symbolic ``size`` is only checked to be positive.

    Parameters
    ----------
    fgraph
        Function graph being rewritten (not used by this rewrite).
    node
        Apply node whose op is an ``ExpandDims``.

    Returns
    -------
    list
        Single-element list holding the replacement xtensor variable.
    """
    [x] = node.inputs
    out_type = node.outputs[0].type

    if getattr(node.op, "dims", None) is None:
        # The op was built without a new dimension name: it is an identity,
        # so there is no axis to insert.
        return [x]

    x_tensor = tensor_from_xtensor(x)
    x_tensor_expanded = expand_dims(x_tensor, axis=0)

    target_shape = out_type.shape
    size = getattr(node.op, "size", 1)

    if isinstance(size, int | np.integer):
        if size != 1:
            if None in target_shape:
                # Some input dims are only known at runtime: build the target
                # from the symbolic shape so the requested size is still
                # honoured (previously the broadcast was silently skipped and
                # the result disagreed with the declared output type).
                broadcast_shape = (int(size), *x_tensor.shape)
            else:
                broadcast_shape = target_shape
            x_tensor_expanded = broadcast_to(x_tensor_expanded, broadcast_shape)
    else:
        # Symbolic size: validate at runtime that it is positive.
        # NOTE(review): the symbolic size is not used to broadcast here, so the
        # new axis keeps length 1 at runtime; the leading entry of
        # `target_shape` is None in this branch, hence `specify_shape` only
        # pins the remaining dims. Confirm this is the intended semantics.
        x_tensor_expanded = Assert(msg="size must be positive")(
            x_tensor_expanded, gt(size, 0)
        )
        x_tensor_expanded = specify_shape(x_tensor_expanded, target_shape)

    new_out = xtensor_from_tensor(x_tensor_expanded, dims=out_type.dims)
    return [new_out]

pytensor/xtensor/shape.py

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33
from types import EllipsisType
44
from typing import Literal
55

6+
import numpy as np
7+
68
from pytensor.graph import Apply
79
from pytensor.scalar import discrete_dtypes, upcast
810
from pytensor.tensor import as_tensor, get_scalar_constant_value
@@ -380,3 +382,75 @@ def squeeze(x, dim=None):
380382
return x # no-op if nothing to squeeze
381383

382384
return Squeeze(dims=dims)(x)
385+
386+
387+
class ExpandDims(XOp):
    """Add a new leading dimension to an XTensorVariable.

    Parameters
    ----------
    dim : str or None
        Name of the dimension to add, stored as ``self.dims``. With ``None``
        the op leaves dims and shape unchanged.
    size : int or symbolic scalar, optional
        Length of the new dimension (default 1). Only an integer size is
        recorded in the static shape of the output; anything else yields an
        unknown (``None``) static length.
    """

    __props__ = ("dims", "size")

    def __init__(self, dim, size=1):
        self.dims = dim
        self.size = size

    def make_node(self, x):
        """Build the Apply node that prepends ``self.dims`` to the dims of ``x``."""
        x = as_xtensor(x)

        if self.dims is None:
            # No-op. The output must be a fresh variable: passing `x` itself
            # as the output would make the input owned by this very node.
            return Apply(self, [x], [x.type()])

        # Plain Python int keeps NumPy integer objects out of the static shape
        if isinstance(self.size, int | np.integer):
            leading_length = int(self.size)
        else:
            leading_length = None  # symbolic size

        out = xtensor(
            dtype=x.type.dtype,
            # The new dim is always inserted at the front
            shape=(leading_length, *x.type.shape),
            dims=(self.dims, *x.type.dims),
        )
        return Apply(self, [x], [out])
418+
419+
420+
def expand_dims(x, dim: str | None, size=1):
    """Prepend a named dimension to an XTensorVariable.

    Parameters
    ----------
    x : XTensorVariable
        Input tensor (converted with ``as_xtensor``).
    dim : str or None
        Name of the dimension to add. ``None`` returns the converted input
        unchanged.
    size : int or symbolic, optional
        Length of the new dimension (default 1). Either a positive integer or
        an object whose ``ndim`` is 0.

    Returns
    -------
    XTensorVariable
        The input with the new dimension placed first.

    Raises
    ------
    TypeError
        If ``dim`` is neither a string nor None, or ``size`` is neither an
        integer nor a 0-d object.
    ValueError
        If ``dim`` is already a dimension of ``x`` or an integer ``size`` is
        not positive.
    """
    x = as_xtensor(x)

    # Nothing to add
    if dim is None:
        return x

    if not isinstance(dim, str):
        raise TypeError(f"`dim` must be a string or None, got: {type(dim)}")

    existing_dims = x.type.dims
    if dim in existing_dims:
        raise ValueError(f"Dimension {dim} already exists in {existing_dims}")

    size_is_static = isinstance(size, int | np.integer)
    if size_is_static and size <= 0:
        raise ValueError(f"size must be positive, got: {size}")
    # Anything that is not an integer has to look like a scalar variable
    if not size_is_static and getattr(size, "ndim", None) != 0:
        raise TypeError(f"size must be an int or scalar variable, got: {type(size)}")

    op = ExpandDims(dim=dim, size=size)
    return op(x)

tests/xtensor/test_shape.py

Lines changed: 208 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,12 +8,14 @@
88
from itertools import chain, combinations
99

1010
import numpy as np
11-
import pytest
11+
import xarray as xr
1212
from xarray import DataArray
1313
from xarray import concat as xr_concat
1414

15+
from pytensor.tensor import scalar
1516
from pytensor.xtensor.shape import (
1617
concat,
18+
expand_dims,
1719
squeeze,
1820
stack,
1921
transpose,
@@ -301,6 +303,15 @@ def test_squeeze_explicit_dims():
301303
fn3d = xr_function([x3], y3d)
302304
xr_assert_allclose(fn3d(x3_test), x3_test)
303305

306+
# Reversibility with expand_dims
307+
x6 = xtensor("x6", dims=("a", "b", "c"), shape=(2, 1, 3))
308+
y6 = squeeze(x6, "b")
309+
# First expand_dims adds at front, then transpose puts it in the right place
310+
z6 = transpose(expand_dims(y6, "b"), "a", "b", "c")
311+
fn6 = xr_function([x6], z6)
312+
x6_test = xr_arange_like(x6)
313+
xr_assert_allclose(fn6(x6_test), x6_test)
314+
304315

305316
def test_squeeze_implicit_dims():
306317
"""Test squeeze with implicit dim=None (all size-1 dimensions)."""
@@ -369,3 +380,199 @@ def test_squeeze_errors():
369380
fn2 = xr_function([x2], y2)
370381
with pytest.raises(Exception):
371382
fn2(x2_test)
383+
384+
385+
def test_expand_dims_explicit():
    """Check expand_dims with named dimensions and static sizes against xarray."""

    # Inputs of rank 1, 2 and 3 each gain a leading "country" dimension
    for dims, shape in (
        (("city",), (3,)),
        (("city", "year"), (2, 2)),
        (("city", "year", "month"), (2, 2, 2)),
    ):
        x = xtensor("x", dims=dims, shape=shape)
        fn = xr_function([x], expand_dims(x, "country"))
        x_test = xr_arange_like(x)
        xr_assert_allclose(fn(x_test), x_test.expand_dims("country"))

    # The new dim is always prepended, whatever its name
    x = xtensor("x", dims=("a", "b"), shape=(2, 3))
    x_test = xr_arange_like(x)
    for new_dim in ("x", "y", "z"):
        out_type = expand_dims(x, new_dim).type
        assert out_type.dims == (new_dim, "a", "b")
        assert out_type.shape == (1, 2, 3)

    # Passing size=1 explicitly matches the default
    fn_explicit = xr_function([x], expand_dims(x, "batch", size=1))
    fn_default = xr_function([x], expand_dims(x, "batch"))
    xr_assert_allclose(fn_explicit(x_test), fn_default(x_test))

    # Expanding a scalar
    x_scalar = xtensor("x", dims=(), shape=())
    y_scalar = expand_dims(x_scalar, "batch")
    assert y_scalar.type.dims == ("batch",)
    assert y_scalar.type.shape == (1,)
    fn = xr_function([x_scalar], y_scalar)
    x_scalar_test = xr_arange_like(x_scalar)
    xr_assert_allclose(fn(x_scalar_test), x_scalar_test.expand_dims("batch"))

    # Static size > 1 broadcasts along the new dim
    x = xtensor("x", dims=("a", "b"), shape=(2, 3))
    x_test = xr_arange_like(x)
    fn = xr_function([x], expand_dims(x, "batch", size=4))
    expected = xr.DataArray(
        np.broadcast_to(x_test.data, (4, 2, 3)),
        dims=("batch", "a", "b"),
        coords={"a": x_test.coords["a"], "b": x_test.coords["b"]},
    )
    xr_assert_allclose(fn(x_test), expected)

    # Placing the new dim between existing ones takes an explicit transpose
    x = xtensor("x", dims=("a", "b"), shape=(2, 3))
    x_test = xr_arange_like(x)
    y = transpose(expand_dims(x, "new"), "a", "new", "b")
    fn = xr_function([x], y)
    expected = x_test.expand_dims("new").transpose("a", "new", "b")
    xr_assert_allclose(fn(x_test), expected)

    # Chained expansions
    x = xtensor("x", dims=(), shape=())
    x_test = xr_arange_like(x)
    y = expand_dims(expand_dims(x, "a"), "b")
    fn = xr_function([x], y)
    expected = x_test.expand_dims("a").expand_dims("b")
    xr_assert_allclose(fn(x_test), expected)
457+
458+
459+
def test_expand_dims_implicit():
    """Test expand_dims with symbolic sizes, squeeze round-trips and dim=None.

    The expectations below pin the current behaviour for a symbolic ``size``:
    whatever positive value is supplied at runtime, the result is compared
    against xarray's plain ``expand_dims``, i.e. a new dimension of length 1.
    """
    x = xtensor("x", dims=("a", "b"), shape=(2, 3))
    x_test = xr_arange_like(x)
    expanded = x_test.expand_dims("batch")

    size_sym = scalar("size_sym", dtype="int64")
    y = expand_dims(x, "batch", size=size_sym)

    # Symbolic size on its own: runtime values 1 and 4 give the same result
    fn = xr_function([x, size_sym], y, on_unused_input="ignore")
    for runtime_size in (1, 4):
        xr_assert_allclose(fn(x_test, runtime_size), expanded)

    # Elementwise op on the expanded variable (compiled once, run for both
    # runtime sizes instead of repeating the identical scenario)
    fn = xr_function([x, size_sym], y + y, on_unused_input="ignore")
    for runtime_size in (4, 1):
        xr_assert_allclose(fn(x_test, runtime_size), expanded + expanded)

    # Reduction over the new dim
    fn = xr_function([x, size_sym], y.sum("batch"), on_unused_input="ignore")
    xr_assert_allclose(fn(x_test, 4), expanded.sum("batch"))

    # Transpose of the expanded variable
    z = transpose(y, "batch", "a", "b")
    fn = xr_function([x, size_sym], z, on_unused_input="ignore")
    xr_assert_allclose(fn(x_test, 4), expanded.transpose("batch", "a", "b"))

    # Reversibility: expand then squeeze
    z = squeeze(expand_dims(x, "batch"), "batch")
    fn = xr_function([x], z)
    xr_assert_allclose(fn(x_test), x_test)

    # expand_dims with dim=None is a no-op
    x1 = xtensor("x", dims=("a",), shape=(3,))
    fn = xr_function([x1], expand_dims(x1, None))
    x1_test = xr_arange_like(x1)
    xr_assert_allclose(fn(x1_test), x1_test)
545+
546+
547+
def test_expand_dims_errors():
    """Test error handling in expand_dims.

    Build-time errors are matched against their messages so that an unrelated
    TypeError/ValueError cannot make a check pass by accident.
    """
    x = xtensor("x", dims=("city",), shape=(3,))

    # Expanding with a name the input already has
    y = expand_dims(x, "country")
    with pytest.raises(ValueError, match="already exists"):
        expand_dims(y, "city")

    # Size = 0 is invalid
    with pytest.raises(ValueError, match="size must be.*positive"):
        expand_dims(x, "batch", size=0)

    # Invalid dim type
    with pytest.raises(TypeError, match="must be a string or None"):
        expand_dims(x, 123)

    # Invalid size type
    with pytest.raises(TypeError, match="size must be an int or scalar variable"):
        expand_dims(x, "new", size=[1])

    # Re-adding a dimension created by a previous expand_dims
    y = expand_dims(x, "new")
    with pytest.raises(ValueError, match="already exists"):
        expand_dims(y, "new")

    # Symbolic size with invalid runtime value. The exception type raised by
    # the compiled function is not visible from this file, so it is left broad.
    size_sym = scalar("size_sym", dtype="int64")
    y = expand_dims(x, "batch", size=size_sym)
    fn = xr_function([x, size_sym], y, on_unused_input="ignore")
    with pytest.raises(Exception):
        fn(xr_arange_like(x), 0)

0 commit comments

Comments
 (0)