Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 20 additions & 15 deletions pytensor/xtensor/shape.py
Original file line number Diff line number Diff line change
Expand Up @@ -354,24 +354,29 @@ def make_node(self, x):
return Apply(self, [x], [out])


def squeeze(x, dim=None):
"""Remove dimensions of size 1 from an XTensorVariable.
def squeeze(x, dim=None, drop=False, axis=None):
"""Remove dimensions of size 1 from an XTensorVariable."""
x = as_xtensor(x)

Parameters
----------
x : XTensorVariable
The input tensor
dim : str or None or iterable of str, optional
The name(s) of the dimension(s) to remove. If None, all dimensions of size 1
(known statically) will be removed. Dimensions with unknown static shape will be retained, even if they have size 1 at runtime.
# drop parameter is ignored in pytensor.xtensor
if drop is not None:
warnings.warn("drop parameter has no effect in pytensor.xtensor", UserWarning)

Returns
-------
XTensorVariable
A new tensor with the specified dimension(s) removed.
"""
x = as_xtensor(x)
# dim and axis are mutually exclusive
if dim is not None and axis is not None:
raise ValueError("Cannot specify both `dim` and `axis`")

# if axis is specified, it must be a sequence of ints
if axis is not None:
if not isinstance(axis, Sequence):
axis = [axis]
if not all(isinstance(a, int) for a in axis):
raise ValueError("axis must be an integer or a sequence of integers")

# convert axis to dims
dims = tuple(x.type.dims[i] for i in axis)

# if dim is specified, it must be a string or a sequence of strings
if dim is None:
dims = tuple(d for d, s in zip(x.type.dims, x.type.shape) if s == 1)
elif isinstance(dim, str):
Expand Down
26 changes: 21 additions & 5 deletions pytensor/xtensor/type.py
Original file line number Diff line number Diff line change
Expand Up @@ -474,17 +474,33 @@ def thin(self, indexers: dict[str, Any] | int | None = None, **indexers_kwargs):
def squeeze(
self,
dim: Sequence[str] | str | None = None,
drop: bool = False,
drop=None,
axis: int | Sequence[int] | None = None,
):
if axis is not None:
raise NotImplementedError("Squeeze with axis not Implemented")
return px.shape.squeeze(self, dim)
"""Remove dimensions of size 1 from an XTensorVariable.

Parameters
----------
x : XTensorVariable
The input tensor
dim : str or None or iterable of str, optional
The name(s) of the dimension(s) to remove. If None, all dimensions of size 1
(known statically) will be removed. Dimensions with unknown static shape will be retained, even if they have size 1 at runtime.
drop : bool, optional
If drop=True, drop squeezed coordinates instead of making them scalar.
axis : int or iterable of int, optional
The axis(es) to remove. If None, all dimensions of size 1 will be removed.
Returns
-------
XTensorVariable
A new tensor with the specified dimension(s) removed.
"""
return px.shape.squeeze(self, dim, drop, axis)

def expand_dims(
self,
dim: str | Sequence[str] | dict[str, int | Sequence] | None = None,
create_index_for_new_dim: bool = True,
create_index_for_new_dim=None,
axis: int | Sequence[int] | None = None,
**dim_kwargs,
):
Expand Down
121 changes: 50 additions & 71 deletions tests/xtensor/test_shape.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
from pytensor.tensor import scalar
from pytensor.xtensor.shape import (
concat,
squeeze,
stack,
transpose,
unstack,
Expand Down Expand Up @@ -265,89 +264,74 @@ def test_concat_scalar():
xr_assert_allclose(res, expected_res)


def test_squeeze_explicit_dims():
"""Test squeeze with explicit dimension(s)."""
def test_squeeze():
"""Test squeeze."""

# Single dimension
x1 = xtensor("x1", dims=("city", "country"), shape=(3, 1))
y1 = squeeze(x1, "country")
y1 = x1.squeeze("country")
fn1 = xr_function([x1], y1)
x1_test = xr_arange_like(x1)
xr_assert_allclose(fn1(x1_test), x1_test.squeeze("country"))

# Multiple dimensions
# Multiple dimensions and order independence
x2 = xtensor("x2", dims=("a", "b", "c", "d"), shape=(2, 1, 1, 3))
y2 = squeeze(x2, ["b", "c"])
fn2 = xr_function([x2], y2)
x2_test = xr_arange_like(x2)
xr_assert_allclose(fn2(x2_test), x2_test.squeeze(["b", "c"]))

# Order independence
x3 = xtensor("x3", dims=("a", "b", "c"), shape=(2, 1, 1))
y3a = squeeze(x3, ["b", "c"])
y3b = squeeze(x3, ["c", "b"])
fn3a = xr_function([x3], y3a)
fn3b = xr_function([x3], y3b)
x3_test = xr_arange_like(x3)
xr_assert_allclose(fn3a(x3_test), fn3b(x3_test))

# Redundant dimensions
y3c = squeeze(x3, ["b", "b"])
fn3c = xr_function([x3], y3c)
xr_assert_allclose(fn3c(x3_test), x3_test.squeeze(["b", "b"]))

# Empty list = no-op
y3d = squeeze(x3, [])
fn3d = xr_function([x3], y3d)
xr_assert_allclose(fn3d(x3_test), x3_test)


def test_squeeze_implicit_dims():
"""Test squeeze with implicit dim=None (all size-1 dimensions)."""

# All dimensions size 1
x1 = xtensor("x1", dims=("a", "b"), shape=(1, 1))
y1 = squeeze(x1)
fn1 = xr_function([x1], y1)
x1_test = xr_arange_like(x1)
xr_assert_allclose(fn1(x1_test), x1_test.squeeze())

# No dimensions size 1 = no-op
x2 = xtensor("x2", dims=("row", "col", "batch"), shape=(2, 3, 4))
y2 = squeeze(x2)
fn2 = xr_function([x2], y2)
y2a = x2.squeeze(["b", "c"])
y2b = x2.squeeze(["c", "b"]) # Test order independence
y2c = x2.squeeze(["b", "b"]) # Test redundant dimensions
y2d = x2.squeeze([]) # Test empty list (no-op)
fn2a = xr_function([x2], y2a)
fn2b = xr_function([x2], y2b)
fn2c = xr_function([x2], y2c)
fn2d = xr_function([x2], y2d)
x2_test = xr_arange_like(x2)
xr_assert_allclose(fn2(x2_test), x2_test)
xr_assert_allclose(fn2a(x2_test), x2_test.squeeze(["b", "c"]))
xr_assert_allclose(fn2b(x2_test), x2_test.squeeze(["c", "b"]))
xr_assert_allclose(fn2c(x2_test), x2_test.squeeze(["b", "b"]))
xr_assert_allclose(fn2d(x2_test), x2_test)

# Symbolic shape where runtime shape is 1 → should squeeze
# Unknown shapes
x3 = xtensor("x3", dims=("a", "b", "c")) # shape unknown
y3 = squeeze(x3, "b")
y3 = x3.squeeze("b")
x3_test = xr_arange_like(xtensor(dims=x3.dims, shape=(2, 1, 3)))
fn3 = xr_function([x3], y3)
xr_assert_allclose(fn3(x3_test), x3_test.squeeze("b"))

# Mixed static + symbolic shapes, where symbolic shape is 1
# Mixed known + unknown shapes
x4 = xtensor("x4", dims=("a", "b", "c"), shape=(None, 1, 3))
y4 = squeeze(x4, "b")
y4 = x4.squeeze("b")
x4_test = xr_arange_like(xtensor(dims=x4.dims, shape=(4, 1, 3)))
fn4 = xr_function([x4], y4)
xr_assert_allclose(fn4(x4_test), x4_test.squeeze("b"))

"""
This test documents that we intentionally don't squeeze dimensions with symbolic shapes
(static_shape=None) even when they are 1 at runtime, while xarray does squeeze them.
"""
# Create a tensor with a symbolic dimension that will be 1 at runtime
x = xtensor("x", dims=("a", "b", "c")) # shape unknown
y = squeeze(x) # implicit dim=None should not squeeze symbolic dimensions
x_test = xr_arange_like(xtensor(dims=x.dims, shape=(2, 1, 3)))
fn = xr_function([x], y)
res = fn(x_test)
# Test axis parameter
x5 = xtensor("x5", dims=("a", "b", "c"), shape=(2, 1, 3))
y5 = x5.squeeze(axis=1) # squeeze dimension at index 1 (b)
fn5 = xr_function([x5], y5)
x5_test = xr_arange_like(x5)
xr_assert_allclose(fn5(x5_test), x5_test.squeeze(axis=1))

# Test axis parameter with negative index
y5 = x5.squeeze(axis=-1) # squeeze dimension at index -2 (b)
fn5 = xr_function([x5], y5)
x5_test = xr_arange_like(x5)
xr_assert_allclose(fn5(x5_test), x5_test.squeeze(axis=-2))

# Test axis parameter with sequence of ints
y6 = x2.squeeze(axis=[1, 2])
fn6 = xr_function([x2], y6)
x2_test = xr_arange_like(x2)
xr_assert_allclose(fn6(x2_test), x2_test.squeeze(axis=[1, 2]))

# Our implementation should not squeeze the symbolic dimension
assert "b" in res.dims
# While xarray would squeeze it
assert "b" not in x_test.squeeze().dims
# Test drop parameter warning
x7 = xtensor("x7", dims=("a", "b"), shape=(2, 1))
with pytest.warns(
UserWarning, match="drop parameter has no effect in pytensor.xtensor"
):
y7 = x7.squeeze("b", drop=True) # squeeze and drop coordinate
fn7 = xr_function([x7], y7)
x7_test = xr_arange_like(x7)
xr_assert_allclose(fn7(x7_test), x7_test.squeeze("b", drop=True))


def test_squeeze_errors():
Expand All @@ -356,15 +340,15 @@ def test_squeeze_errors():
# Non-existent dimension
x1 = xtensor("x1", dims=("city", "country"), shape=(3, 1))
with pytest.raises(ValueError, match="Dimension .* not found"):
squeeze(x1, "time")
x1.squeeze("time")

# Dimension size > 1
with pytest.raises(ValueError, match="has static size .* not 1"):
squeeze(x1, "city")
x1.squeeze("city")

# Symbolic shape: dim is not 1 at runtime → should raise
x2 = xtensor("x2", dims=("a", "b", "c")) # shape unknown
y2 = squeeze(x2, "b")
y2 = x2.squeeze("b")
x2_test = xr_arange_like(xtensor(dims=x2.dims, shape=(2, 2, 3)))
fn2 = xr_function([x2], y2)
with pytest.raises(Exception):
Expand Down Expand Up @@ -471,11 +455,6 @@ def test_expand_dims_errors():
with pytest.raises(ValueError, match="already exists"):
y.expand_dims("new")

# Find out what xarray does with a numpy array as dim
# x_test = xr_arange_like(x)
# x_test.expand_dims(np.array([1, 2]))
# TypeError: unhashable type: 'numpy.ndarray'

# Test with a numpy array as dim (not supported)
with pytest.raises(TypeError, match="unhashable type"):
y.expand_dims(np.array([1, 2]))