Skip to content

Commit 260b9b6

Browse files
committed
Removing expand_dims
1 parent 1024798 commit 260b9b6

File tree

3 files changed

+7
-280
lines changed

3 files changed

+7
-280
lines changed

pytensor/xtensor/rewriting/shape.py

Lines changed: 0 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,6 @@
1-
import numpy as np
2-
31
from pytensor.graph import node_rewriter
4-
from pytensor.raise_op import Assert
52
from pytensor.tensor import (
63
broadcast_to,
7-
expand_dims,
8-
gt,
94
join,
105
moveaxis,
116
specify_shape,
@@ -15,7 +10,6 @@
1510
from pytensor.xtensor.rewriting.basic import register_xcanonicalize
1611
from pytensor.xtensor.shape import (
1712
Concat,
18-
ExpandDims,
1913
Squeeze,
2014
Stack,
2115
Transpose,
@@ -125,40 +119,6 @@ def lower_transpose(fgraph, node):
125119
return [new_out]
126120

127121

128-
@register_xcanonicalize
129-
@node_rewriter([ExpandDims])
130-
def local_expand_dims_reshape(fgraph, node):
131-
"""Rewrite ExpandDims to tensor.expand_dims and optionally broadcast_to or specify_shape."""
132-
if not isinstance(node.op, ExpandDims):
133-
return False
134-
135-
x = node.inputs[0]
136-
dim = node.op.dims
137-
size = getattr(node.op, "size", 1)
138-
139-
if dim is None:
140-
return [x]
141-
142-
x_tensor = tensor_from_xtensor(x)
143-
x_tensor_expanded = expand_dims(x_tensor, axis=0)
144-
145-
target_shape = node.outputs[0].type.shape
146-
147-
if isinstance(size, int | np.integer):
148-
if size != 1 and None not in target_shape:
149-
x_tensor_expanded = broadcast_to(x_tensor_expanded, target_shape)
150-
else:
151-
# Symbolic size: enforce shape so broadcast happens downstream correctly
152-
# Also validate that size is positive
153-
x_tensor_expanded = Assert(msg="size must be positive")(
154-
x_tensor_expanded, gt(size, 0)
155-
)
156-
x_tensor_expanded = specify_shape(x_tensor_expanded, target_shape)
157-
158-
new_out = xtensor_from_tensor(x_tensor_expanded, dims=node.outputs[0].type.dims)
159-
return [new_out]
160-
161-
162122
@register_xcanonicalize
163123
@node_rewriter([Squeeze])
164124
def local_squeeze_reshape(fgraph, node):

pytensor/xtensor/shape.py

Lines changed: 0 additions & 80 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,6 @@
22
from collections.abc import Sequence
33
from typing import Literal
44

5-
import numpy as np
6-
75
from pytensor import Variable
86
from pytensor.graph import Apply
97
from pytensor.scalar import discrete_dtypes, upcast
@@ -305,84 +303,6 @@ def concat(xtensors, dim: str):
305303
return Concat(dim=dim)(*xtensors)
306304

307305

308-
class ExpandDims(XOp):
309-
"""Add a new dimension to an XTensorVariable."""
310-
311-
__props__ = ("dims", "size")
312-
313-
def __init__(self, dim, size=1):
314-
self.dims = dim
315-
self.size = size
316-
317-
def make_node(self, x):
318-
x = as_xtensor(x)
319-
320-
if self.dims is None:
321-
# No-op: return same variable
322-
return Apply(self, [x], [x])
323-
324-
# Insert new dim at front
325-
new_dims = (self.dims, *x.type.dims)
326-
327-
# Determine shape
328-
if isinstance(self.size, int | np.integer):
329-
new_shape = (self.size, *x.type.shape)
330-
else:
331-
new_shape = (None, *x.type.shape) # symbolic size
332-
333-
out = xtensor(
334-
dtype=x.type.dtype,
335-
shape=new_shape,
336-
dims=new_dims,
337-
)
338-
return Apply(self, [x], [out])
339-
340-
def infer_shape(self, fgraph, node, input_shapes):
341-
(input_shape,) = input_shapes
342-
if self.dims is None:
343-
return [input_shape]
344-
return [(self.size, *list(input_shape))]
345-
346-
347-
def expand_dims(x, dim: str | None, size=1):
348-
"""Add a new dimension to an XTensorVariable.
349-
350-
Parameters
351-
----------
352-
x : XTensorVariable
353-
Input tensor
354-
dim : str or None
355-
Name of new dimension. If None, returns x unchanged.
356-
size : int or symbolic, optional
357-
Size of the new dimension (default 1)
358-
359-
Returns
360-
-------
361-
XTensorVariable
362-
Tensor with the new dimension inserted
363-
"""
364-
x = as_xtensor(x)
365-
366-
if dim is None:
367-
return x # No-op
368-
369-
if not isinstance(dim, str):
370-
raise TypeError(f"`dim` must be a string or None, got: {type(dim)}")
371-
372-
if dim in x.type.dims:
373-
raise ValueError(f"Dimension {dim} already exists in {x.type.dims}")
374-
375-
if isinstance(size, int | np.integer):
376-
if size <= 0:
377-
raise ValueError(f"size must be positive, got: {size}")
378-
elif not (
379-
hasattr(size, "ndim") and getattr(size, "ndim", None) == 0 # symbolic scalar
380-
):
381-
raise TypeError(f"size must be an int or scalar variable, got: {type(size)}")
382-
383-
return ExpandDims(dim=dim, size=size)(x)
384-
385-
386306
class Squeeze(XOp):
387307
"""Remove specified dimensions from an XTensorVariable.
388308

tests/xtensor/test_shape.py

Lines changed: 7 additions & 160 deletions
Original file line numberDiff line numberDiff line change
@@ -9,14 +9,11 @@
99

1010
import numpy as np
1111
import pytest
12-
import xarray as xr
1312
from xarray import DataArray
1413
from xarray import concat as xr_concat
1514

16-
from pytensor.tensor import scalar
1715
from pytensor.xtensor.shape import (
1816
concat,
19-
expand_dims,
2017
squeeze,
2118
stack,
2219
transpose,
@@ -269,156 +266,6 @@ def test_concat_scalar():
269266
xr_assert_allclose(res, expected_res)
270267

271268

272-
def test_expand_dims_explicit():
273-
"""Test expand_dims with explicitly named dimensions and sizes."""
274-
275-
# 1D case
276-
x = xtensor("x", dims=("city",), shape=(3,))
277-
y = expand_dims(x, "country")
278-
fn = xr_function([x], y)
279-
x_xr = xr_arange_like(x)
280-
xr_assert_allclose(fn(x_xr), x_xr.expand_dims("country"))
281-
282-
# 2D case
283-
x = xtensor("x", dims=("city", "year"), shape=(2, 2))
284-
y = expand_dims(x, "country")
285-
fn = xr_function([x], y)
286-
xr_assert_allclose(fn(xr_arange_like(x)), xr_arange_like(x).expand_dims("country"))
287-
288-
# 3D case
289-
x = xtensor("x", dims=("city", "year", "month"), shape=(2, 2, 2))
290-
y = expand_dims(x, "country")
291-
fn = xr_function([x], y)
292-
xr_assert_allclose(fn(xr_arange_like(x)), xr_arange_like(x).expand_dims("country"))
293-
294-
# Prepending various dims
295-
x = xtensor("x", dims=("a", "b"), shape=(2, 3))
296-
for new_dim in ("x", "y", "z"):
297-
y = expand_dims(x, new_dim)
298-
assert y.type.dims == (new_dim, "a", "b")
299-
assert y.type.shape == (1, 2, 3)
300-
301-
# Explicit size=1 behaves like default
302-
y1 = expand_dims(x, "batch", size=1)
303-
y2 = expand_dims(x, "batch")
304-
fn1 = xr_function([x], y1)
305-
fn2 = xr_function([x], y2)
306-
x_test = xr_arange_like(x)
307-
xr_assert_allclose(fn1(x_test), fn2(x_test))
308-
309-
# Scalar expansion
310-
x = xtensor("x", dims=(), shape=())
311-
y = expand_dims(x, "batch")
312-
assert y.type.dims == ("batch",)
313-
assert y.type.shape == (1,)
314-
fn = xr_function([x], y)
315-
xr_assert_allclose(fn(xr_arange_like(x)), xr_arange_like(x).expand_dims("batch"))
316-
317-
# Static size > 1: broadcast
318-
x = xtensor("x", dims=("a", "b"), shape=(2, 3))
319-
y = expand_dims(x, "batch", size=4)
320-
fn = xr_function([x], y)
321-
expected = xr.DataArray(
322-
np.broadcast_to(xr_arange_like(x).data, (4, 2, 3)),
323-
dims=("batch", "a", "b"),
324-
coords={"a": xr_arange_like(x).coords["a"], "b": xr_arange_like(x).coords["b"]},
325-
)
326-
xr_assert_allclose(fn(xr_arange_like(x)), expected)
327-
328-
# Insert new dim between existing dims
329-
x = xtensor("x", dims=("a", "b"), shape=(2, 3))
330-
y = expand_dims(x, "new")
331-
# Insert new dim between a and b: ("a", "new", "b")
332-
y = transpose(y, "a", "new", "b")
333-
fn = xr_function([x], y)
334-
x_test = xr_arange_like(x)
335-
expected = x_test.expand_dims("new").transpose("a", "new", "b")
336-
xr_assert_allclose(fn(x_test), expected)
337-
338-
# Expand with multiple dims
339-
x = xtensor("x", dims=(), shape=())
340-
y = expand_dims(expand_dims(x, "a"), "b")
341-
fn = xr_function([x], y)
342-
expected = xr_arange_like(x).expand_dims("a").expand_dims("b")
343-
xr_assert_allclose(fn(xr_arange_like(x)), expected)
344-
345-
346-
def test_expand_dims_implicit():
347-
"""Test expand_dims with default or symbolic sizes and dim=None."""
348-
349-
# Symbolic size=1: same as default
350-
size_sym_1 = scalar("size_sym_1", dtype="int64")
351-
x = xtensor("x", dims=("a", "b"), shape=(2, 3))
352-
y = expand_dims(x, "batch", size=size_sym_1)
353-
fn = xr_function([x, size_sym_1], y, on_unused_input="ignore")
354-
expected = xr_arange_like(x).expand_dims("batch")
355-
xr_assert_allclose(fn(xr_arange_like(x), 1), expected)
356-
357-
# Symbolic size > 1 (but expand only adds dim=1)
358-
size_sym_4 = scalar("size_sym_4", dtype="int64")
359-
y = expand_dims(x, "batch", size=size_sym_4)
360-
fn = xr_function([x, size_sym_4], y, on_unused_input="ignore")
361-
xr_assert_allclose(fn(xr_arange_like(x), 4), expected)
362-
363-
# Reversibility: expand then squeeze
364-
x = xtensor("x", dims=("a", "b"), shape=(2, 3))
365-
y = expand_dims(x, "batch")
366-
z = squeeze(y, "batch")
367-
fn = xr_function([x], z)
368-
xr_assert_allclose(fn(xr_arange_like(x)), xr_arange_like(x))
369-
370-
# expand_dims with dim=None = no-op
371-
x = xtensor("x", dims=("a",), shape=(3,))
372-
y = expand_dims(x, None)
373-
fn = xr_function([x], y)
374-
xr_assert_allclose(fn(xr_arange_like(x)), xr_arange_like(x))
375-
376-
# broadcast after symbolic size
377-
size_sym = scalar("size_sym", dtype="int64")
378-
x = xtensor("x", dims=("a", "b"), shape=(2, 3))
379-
y = expand_dims(x, "batch", size=size_sym)
380-
z = y + y # triggers shape alignment
381-
fn = xr_function([x, size_sym], z, on_unused_input="ignore")
382-
x_test = xr_arange_like(x)
383-
out = fn(x_test, 1)
384-
expected = x_test.expand_dims("batch") + x_test.expand_dims("batch")
385-
xr_assert_allclose(out, expected)
386-
387-
388-
def test_expand_dims_errors():
389-
"""Test error handling in expand_dims."""
390-
391-
# Expanding existing dim
392-
x = xtensor("x", dims=("city",), shape=(3,))
393-
y = expand_dims(x, "country")
394-
with pytest.raises(ValueError, match="already exists"):
395-
expand_dims(y, "city")
396-
397-
# Size = 0 is invalid
398-
with pytest.raises(ValueError, match="size must be.*positive"):
399-
expand_dims(x, "batch", size=0)
400-
401-
# Invalid dim type
402-
with pytest.raises(TypeError):
403-
expand_dims(x, 123)
404-
405-
# Invalid size type
406-
with pytest.raises(TypeError):
407-
expand_dims(x, "new", size=[1])
408-
409-
# Duplicate dimension creation
410-
y = expand_dims(x, "new")
411-
with pytest.raises(ValueError):
412-
expand_dims(y, "new")
413-
414-
# Symbolic size with invalid runtime value
415-
size_sym = scalar("size_sym", dtype="int64")
416-
y = expand_dims(x, "batch", size=size_sym)
417-
fn = xr_function([x, size_sym], y, on_unused_input="ignore")
418-
with pytest.raises(Exception):
419-
fn(xr_arange_like(x), 0)
420-
421-
422269
def test_squeeze_explicit_dims():
423270
"""Test squeeze with explicit dimension(s)."""
424271

@@ -487,13 +334,13 @@ def test_squeeze_implicit_dims():
487334
fn4 = xr_function([x4], y4)
488335
xr_assert_allclose(fn4(x4_test), x4_test.squeeze("b"))
489336

490-
# Reversibility with expand_dims
491-
x5 = xtensor("x5", dims=("batch", "time", "feature"), shape=(2, 1, 3))
492-
y5 = squeeze(x5, "time")
493-
z5 = expand_dims(y5, "time")
494-
fn5 = xr_function([x5], z5)
495-
x5_test = xr_arange_like(x5)
496-
xr_assert_allclose(fn5(x5_test).transpose(*x5_test.dims), x5_test)
337+
# Reversibility with expand_dims (restore when expand_dims is implemented)
338+
# x5 = xtensor("x5", dims=("batch", "time", "feature"), shape=(2, 1, 3))
339+
# y5 = squeeze(x5, "time")
340+
# z5 = expand_dims(y5, "time")
341+
# fn5 = xr_function([x5], z5)
342+
# x5_test = xr_arange_like(x5)
343+
# xr_assert_allclose(fn5(x5_test).transpose(*x5_test.dims), x5_test)
497344

498345
"""
499346
This test documents that we intentionally don't squeeze dimensions with symbolic shapes

0 commit comments

Comments
 (0)