diff --git a/pytensor/xtensor/shape.py b/pytensor/xtensor/shape.py
index f604dc8188..ceded71ec0 100644
--- a/pytensor/xtensor/shape.py
+++ b/pytensor/xtensor/shape.py
@@ -354,24 +354,50 @@ def make_node(self, x):
         return Apply(self, [x], [out])
 
 
-def squeeze(x, dim=None):
+def squeeze(x, dim=None, drop=None, axis=None):
     """Remove dimensions of size 1 from an XTensorVariable.
 
     Parameters
     ----------
     x : XTensorVariable
         The input tensor
     dim : str or None or iterable of str, optional
         The name(s) of the dimension(s) to remove. If None, all dimensions of size 1
         (known statically) will be removed. Dimensions with unknown static shape will be retained, even if they have size 1 at runtime.
+    drop : bool, optional
+        Accepted for xarray API compatibility only. It has no effect; passing any
+        value other than None emits a UserWarning.
+    axis : int or iterable of int, optional
+        Positional index(es) of the dimension(s) to remove. Mutually exclusive
+        with `dim`.
 
     Returns
     -------
     XTensorVariable
         A new tensor with the specified dimension(s) removed.
     """
     x = as_xtensor(x)
 
+    # drop parameter is ignored in pytensor.xtensor
+    if drop is not None:
+        warnings.warn("drop parameter has no effect in pytensor.xtensor", UserWarning)
+
+    # dim and axis are mutually exclusive
+    if dim is not None and axis is not None:
+        raise ValueError("Cannot specify both `dim` and `axis`")
+
+    # if axis is specified, it must be a sequence of ints
+    if axis is not None:
+        if not isinstance(axis, Sequence):
+            axis = [axis]
+        if not all(isinstance(a, int) for a in axis):
+            raise ValueError("axis must be an integer or a sequence of integers")
+
+        # Convert axis to dim names and let the `dim` handling below process them;
+        # assigning `dims` here would be overwritten by the `dim is None` branch.
+        dim = tuple(x.type.dims[i] for i in axis)
+
+    # if dim is specified, it must be a string or a sequence of strings
     if dim is None:
         dims = tuple(d for d, s in zip(x.type.dims, x.type.shape) if s == 1)
     elif isinstance(dim, str):
diff --git a/pytensor/xtensor/type.py b/pytensor/xtensor/type.py
index 9fea411129..96b0a1fd7c 100644
--- a/pytensor/xtensor/type.py
+++ b/pytensor/xtensor/type.py
@@ -474,17 +474,34 @@ def thin(self, indexers: dict[str, Any] | int | None = None, **indexers_kwargs):
     def squeeze(
         self,
         dim: Sequence[str] | str | None = None,
-        drop: bool = False,
+        drop: bool | None = None,
         axis: int | Sequence[int] | None = None,
     ):
-        if axis is not None:
-            raise NotImplementedError("Squeeze with axis not Implemented")
-        return px.shape.squeeze(self, dim)
+        """Remove dimensions of size 1 from an XTensorVariable.
+
+        Parameters
+        ----------
+        dim : str or None or iterable of str, optional
+            The name(s) of the dimension(s) to remove. If None, all dimensions of size 1
+            (known statically) will be removed. Dimensions with unknown static shape will be retained, even if they have size 1 at runtime.
+        drop : bool, optional
+            Accepted for xarray API compatibility only; it has no effect here and
+            passing any value other than None emits a UserWarning.
+        axis : int or iterable of int, optional
+            Positional index(es) of the dimension(s) to remove, as an alternative
+            to `dim`. Mutually exclusive with `dim`.
+
+        Returns
+        -------
+        XTensorVariable
+            A new tensor with the specified dimension(s) removed.
+        """
+        return px.shape.squeeze(self, dim, drop=drop, axis=axis)
 
     def expand_dims(
         self,
         dim: str | Sequence[str] | dict[str, int | Sequence] | None = None,
-        create_index_for_new_dim: bool = True,
+        create_index_for_new_dim=None,
         axis: int | Sequence[int] | None = None,
         **dim_kwargs,
     ):
diff --git a/tests/xtensor/test_shape.py b/tests/xtensor/test_shape.py
index 69802dcec0..b6ec0bdbbc 100644
--- a/tests/xtensor/test_shape.py
+++ b/tests/xtensor/test_shape.py
@@ -14,7 +14,6 @@
 from pytensor.tensor import scalar
 from pytensor.xtensor.shape import (
     concat,
-    squeeze,
     stack,
     transpose,
     unstack,
@@ -265,89 +264,74 @@ def test_concat_scalar():
     xr_assert_allclose(res, expected_res)
 
 
-def test_squeeze_explicit_dims():
-    """Test squeeze with explicit dimension(s)."""
+def test_squeeze():
+    """Test squeeze."""
 
     # Single dimension
     x1 = xtensor("x1", dims=("city", "country"), shape=(3, 1))
-    y1 = squeeze(x1, "country")
+    y1 = x1.squeeze("country")
     fn1 = xr_function([x1], y1)
     x1_test = xr_arange_like(x1)
     xr_assert_allclose(fn1(x1_test), x1_test.squeeze("country"))
 
-    # Multiple dimensions
+    # Multiple dimensions and order independence
     x2 = xtensor("x2", dims=("a", "b", "c", "d"), shape=(2, 1, 1, 3))
-    y2 = squeeze(x2, ["b", "c"])
-    fn2 = xr_function([x2], y2)
-    x2_test = xr_arange_like(x2)
-    xr_assert_allclose(fn2(x2_test), x2_test.squeeze(["b", "c"]))
-
-    # Order independence
-    x3 = xtensor("x3", dims=("a", "b", "c"), shape=(2, 1, 1))
-    y3a = squeeze(x3, ["b", "c"])
-    y3b = squeeze(x3, ["c", "b"])
-    fn3a = xr_function([x3], y3a)
-    fn3b = xr_function([x3], y3b)
-    x3_test = xr_arange_like(x3)
-    xr_assert_allclose(fn3a(x3_test), fn3b(x3_test))
-
-    # Redundant dimensions
-    y3c = squeeze(x3, ["b", "b"])
-    fn3c = xr_function([x3], y3c)
-    xr_assert_allclose(fn3c(x3_test), x3_test.squeeze(["b", "b"]))
-
-    # Empty list = no-op
-    y3d = squeeze(x3, [])
-    fn3d = xr_function([x3], y3d)
-    xr_assert_allclose(fn3d(x3_test), x3_test)
-
-
-def test_squeeze_implicit_dims():
-    """Test squeeze with implicit dim=None (all size-1 dimensions)."""
-
-    # All dimensions size 1
-    x1 = xtensor("x1", dims=("a", "b"), shape=(1, 1))
-    y1 = squeeze(x1)
-    fn1 = xr_function([x1], y1)
-    x1_test = xr_arange_like(x1)
-    xr_assert_allclose(fn1(x1_test), x1_test.squeeze())
-
-    # No dimensions size 1 = no-op
-    x2 = xtensor("x2", dims=("row", "col", "batch"), shape=(2, 3, 4))
-    y2 = squeeze(x2)
-    fn2 = xr_function([x2], y2)
+    y2a = x2.squeeze(["b", "c"])
+    y2b = x2.squeeze(["c", "b"])  # Test order independence
+    y2c = x2.squeeze(["b", "b"])  # Test redundant dimensions
+    y2d = x2.squeeze([])  # Test empty list (no-op)
+    fn2a = xr_function([x2], y2a)
+    fn2b = xr_function([x2], y2b)
+    fn2c = xr_function([x2], y2c)
+    fn2d = xr_function([x2], y2d)
     x2_test = xr_arange_like(x2)
-    xr_assert_allclose(fn2(x2_test), x2_test)
+    xr_assert_allclose(fn2a(x2_test), x2_test.squeeze(["b", "c"]))
+    xr_assert_allclose(fn2b(x2_test), x2_test.squeeze(["c", "b"]))
+    xr_assert_allclose(fn2c(x2_test), x2_test.squeeze(["b", "b"]))
+    xr_assert_allclose(fn2d(x2_test), x2_test)
 
-    # Symbolic shape where runtime shape is 1 → should squeeze
+    # Unknown shapes
     x3 = xtensor("x3", dims=("a", "b", "c"))  # shape unknown
-    y3 = squeeze(x3, "b")
+    y3 = x3.squeeze("b")
     x3_test = xr_arange_like(xtensor(dims=x3.dims, shape=(2, 1, 3)))
     fn3 = xr_function([x3], y3)
     xr_assert_allclose(fn3(x3_test), x3_test.squeeze("b"))
 
-    # Mixed static + symbolic shapes, where symbolic shape is 1
+    # Mixed known + unknown shapes
     x4 = xtensor("x4", dims=("a", "b", "c"), shape=(None, 1, 3))
-    y4 = squeeze(x4, "b")
+    y4 = x4.squeeze("b")
     x4_test = xr_arange_like(xtensor(dims=x4.dims, shape=(4, 1, 3)))
     fn4 = xr_function([x4], y4)
     xr_assert_allclose(fn4(x4_test), x4_test.squeeze("b"))
 
-    """
-    This test documents that we intentionally don't squeeze dimensions with symbolic shapes
-    (static_shape=None) even when they are 1 at runtime, while xarray does squeeze them.
-    """
-    # Create a tensor with a symbolic dimension that will be 1 at runtime
-    x = xtensor("x", dims=("a", "b", "c"))  # shape unknown
-    y = squeeze(x)  # implicit dim=None should not squeeze symbolic dimensions
-    x_test = xr_arange_like(xtensor(dims=x.dims, shape=(2, 1, 3)))
-    fn = xr_function([x], y)
-    res = fn(x_test)
+    # Test axis parameter
+    x5 = xtensor("x5", dims=("a", "b", "c"), shape=(2, 1, 3))
+    y5 = x5.squeeze(axis=1)  # squeeze dimension at index 1 (b)
+    fn5 = xr_function([x5], y5)
+    x5_test = xr_arange_like(x5)
+    xr_assert_allclose(fn5(x5_test), x5_test.squeeze(axis=1))
+
+    # Test axis parameter with negative index
+    y5 = x5.squeeze(axis=-2)  # squeeze dimension at index -2 (b)
+    fn5 = xr_function([x5], y5)
+    x5_test = xr_arange_like(x5)
+    xr_assert_allclose(fn5(x5_test), x5_test.squeeze(axis=-2))
+
+    # Test axis parameter with sequence of ints
+    y6 = x2.squeeze(axis=[1, 2])
+    fn6 = xr_function([x2], y6)
+    x2_test = xr_arange_like(x2)
+    xr_assert_allclose(fn6(x2_test), x2_test.squeeze(axis=[1, 2]))
 
-    # Our implementation should not squeeze the symbolic dimension
-    assert "b" in res.dims
-    # While xarray would squeeze it
-    assert "b" not in x_test.squeeze().dims
+    # Test drop parameter warning
+    x7 = xtensor("x7", dims=("a", "b"), shape=(2, 1))
+    with pytest.warns(
+        UserWarning, match="drop parameter has no effect in pytensor.xtensor"
+    ):
+        y7 = x7.squeeze("b", drop=True)  # squeeze and drop coordinate
+    fn7 = xr_function([x7], y7)
+    x7_test = xr_arange_like(x7)
+    xr_assert_allclose(fn7(x7_test), x7_test.squeeze("b", drop=True))
 
 
 def test_squeeze_errors():
@@ -356,15 +340,15 @@ def test_squeeze_errors():
     # Non-existent dimension
     x1 = xtensor("x1", dims=("city", "country"), shape=(3, 1))
     with pytest.raises(ValueError, match="Dimension .* not found"):
-        squeeze(x1, "time")
+        x1.squeeze("time")
 
     # Dimension size > 1
     with pytest.raises(ValueError, match="has static size .* not 1"):
-        squeeze(x1, "city")
+        x1.squeeze("city")
 
     # Symbolic shape: dim is not 1 at runtime → should raise
     x2 = xtensor("x2", dims=("a", "b", "c"))  # shape unknown
-    y2 = squeeze(x2, "b")
+    y2 = x2.squeeze("b")
     x2_test = xr_arange_like(xtensor(dims=x2.dims, shape=(2, 2, 3)))
     fn2 = xr_function([x2], y2)
     with pytest.raises(Exception):
@@ -471,11 +455,6 @@ def test_expand_dims_errors():
     with pytest.raises(ValueError, match="already exists"):
         y.expand_dims("new")
 
-    # Find out what xarray does with a numpy array as dim
-    # x_test = xr_arange_like(x)
-    # x_test.expand_dims(np.array([1, 2]))
-    # TypeError: unhashable type: 'numpy.ndarray'
-
     # Test with a numpy array as dim (not supported)
     with pytest.raises(TypeError, match="unhashable type"):
         y.expand_dims(np.array([1, 2]))