diff --git a/pytensor/xtensor/__init__.py b/pytensor/xtensor/__init__.py index 29a9f5f996..c2917df685 100644 --- a/pytensor/xtensor/__init__.py +++ b/pytensor/xtensor/__init__.py @@ -3,7 +3,7 @@ import pytensor.xtensor.rewriting from pytensor.xtensor import linalg from pytensor.xtensor.math import dot -from pytensor.xtensor.shape import concat +from pytensor.xtensor.shape import concat, full_like, ones_like, zeros_like from pytensor.xtensor.type import ( as_xtensor, xtensor, diff --git a/pytensor/xtensor/shape.py b/pytensor/xtensor/shape.py index 4868e6e4f7..5c4781fdc0 100644 --- a/pytensor/xtensor/shape.py +++ b/pytensor/xtensor/shape.py @@ -13,6 +13,7 @@ from pytensor.tensor.type import integer_dtypes from pytensor.tensor.utils import get_static_shape_from_size_variables from pytensor.xtensor.basic import XOp +from pytensor.xtensor.math import cast, second from pytensor.xtensor.type import as_xtensor, xtensor @@ -498,3 +499,100 @@ def expand_dims(x, dim=None, create_index_for_new_dim=None, axis=None, **dim_kwa x = Transpose(dims=tuple(target_dims))(x) return x + + +def full_like(x, fill_value, dtype=None): + """Create a new XTensorVariable with the same shape and dimensions, filled with a specified value. + + Parameters + ---------- + x : XTensorVariable + The tensor to fill. + fill_value : scalar or XTensorVariable + The value to fill the new tensor with. + dtype : str or np.dtype, optional + The data type of the new tensor. If None, uses the dtype of the input tensor. + + Returns + ------- + XTensorVariable + A new tensor with the same shape and dimensions as self, filled with fill_value. + + Examples + -------- + >>> from pytensor.xtensor import xtensor, full_like + >>> x = xtensor(dtype="float64", dims=("a", "b"), shape=(2, 3)) + >>> y = full_like(x, 5.0) + >>> assert y.dims == ("a", "b") + >>> assert y.type.shape == (2, 3) + """ + x = as_xtensor(x) + fill_value = as_xtensor(fill_value) + + # Check that fill_value is a scalar (ndim=0) + if fill_value.type.ndim != 0: + raise ValueError( + f"fill_value must be a scalar, got ndim={fill_value.type.ndim}" + ) + + # Handle dtype conversion + if dtype is not None: + # If dtype is specified, cast the fill_value to that dtype + fill_value = cast(fill_value, dtype) + else: + # If dtype is None, cast the fill_value to the input tensor's dtype + # This matches xarray's behavior where it preserves the original dtype + fill_value = cast(fill_value, x.type.dtype) + + # Use the xtensor second function + return second(x, fill_value) + + +def ones_like(x, dtype=None): + """Create a new XTensorVariable with the same shape and dimensions, filled with ones. + + Parameters + ---------- + x : XTensorVariable + The tensor to fill. + dtype : str or np.dtype, optional + The data type of the new tensor. If None, uses the dtype of the input tensor. + + Returns: + XTensorVariable + A new tensor with the same shape and dimensions as self, filled with ones. + + Examples + -------- + >>> from pytensor.xtensor import xtensor, full_like + >>> x = xtensor(dtype="float64", dims=("a", "b"), shape=(2, 3)) + >>> y = ones_like(x) + >>> assert y.dims == ("a", "b") + >>> assert y.type.shape == (2, 3) + """ + return full_like(x, 1.0, dtype=dtype) + + +def zeros_like(x, dtype=None): + """Create a new XTensorVariable with the same shape and dimensions, filled with zeros. + + Parameters + ---------- + x : XTensorVariable + The tensor to fill. + dtype : str or np.dtype, optional + The data type of the new tensor. If None, uses the dtype of the input tensor. + + Returns: + XTensorVariable + A new tensor with the same shape and dimensions as self, filled with zeros. + + Examples + -------- + >>> from pytensor.xtensor import xtensor, full_like + >>> x = xtensor(dtype="float64", dims=("a", "b"), shape=(2, 3)) + >>> y = zeros_like(x) + >>> assert y.dims == ("a", "b") + >>> assert y.type.shape == (2, 3) + """ + return full_like(x, 0.0, dtype=dtype) diff --git a/tests/xtensor/test_shape.py b/tests/xtensor/test_shape.py index da2c5f1913..f7304b40be 100644 --- a/tests/xtensor/test_shape.py +++ b/tests/xtensor/test_shape.py @@ -8,9 +8,10 @@ from itertools import chain, combinations import numpy as np +import xarray as xr from xarray import DataArray -from xarray import concat as xr_concat +import pytensor.xtensor as px from pytensor.tensor import scalar from pytensor.xtensor.shape import ( concat, @@ -226,7 +227,7 @@ def test_concat(dim): x3_test = xr_random_like(x3, rng) res = fn(x1_test, x2_test, x3_test) - expected_res = xr_concat([x1_test, x2_test, x3_test], dim=dim) + expected_res = xr.concat([x1_test, x2_test, x3_test], dim=dim) xr_assert_allclose(res, expected_res) @@ -248,7 +249,7 @@ def test_concat_with_broadcast(dim): x3_test = xr_random_like(x3, rng) x4_test = xr_random_like(x4, rng) res = fn(x1_test, x2_test, x3_test, x4_test) - expected_res = xr_concat([x1_test, x2_test, x3_test, x4_test], dim=dim) + expected_res = xr.concat([x1_test, x2_test, x3_test, x4_test], dim=dim) xr_assert_allclose(res, expected_res) @@ -263,7 +264,7 @@ def test_concat_scalar(): x1_test = xr_random_like(x1) x2_test = xr_random_like(x2) res = fn(x1_test, x2_test) - expected_res = xr_concat([x1_test, x2_test], dim="new_dim") + expected_res = xr.concat([x1_test, x2_test], dim="new_dim") xr_assert_allclose(res, expected_res) @@ -466,3 +467,148 @@ def test_expand_dims_errors(): # Test with a numpy array as dim (not supported) with pytest.raises(TypeError, match="unhashable type"): y.expand_dims(np.array([1, 2])) + + +def test_full_like(): + """Test full_like function, comparing with xarray's full_like.""" + + # Basic functionality with scalar fill_value + x = xtensor("x", dims=("a", "b"), shape=(2, 3), dtype="float64") + x_test = xr_arange_like(x) + + y1 = px.full_like(x, 5.0) + fn1 = xr_function([x], y1) + result1 = fn1(x_test) + expected1 = xr.full_like(x_test, 5.0) + xr_assert_allclose(result1, expected1, check_dtype=True) + + # Other dtypes + x_3d = xtensor("x_3d", dims=("a", "b", "c"), shape=(2, 3, 4), dtype="float32") + x_3d_test = xr_arange_like(x_3d) + + y7 = px.full_like(x_3d, -1.0) + fn7 = xr_function([x_3d], y7) + result7 = fn7(x_3d_test) + expected7 = xr.full_like(x_3d_test, -1.0) + xr_assert_allclose(result7, expected7, check_dtype=True) + + # Integer dtype + y3 = px.full_like(x, 5.0, dtype="int32") + fn3 = xr_function([x], y3) + result3 = fn3(x_test) + expected3 = xr.full_like(x_test, 5.0, dtype="int32") + xr_assert_allclose(result3, expected3, check_dtype=True) + + # Different fill_value types + y4 = px.full_like(x, np.array(3.14)) + fn4 = xr_function([x], y4) + result4 = fn4(x_test) + expected4 = xr.full_like(x_test, 3.14) + xr_assert_allclose(result4, expected4, check_dtype=True) + + # Integer input with float fill_value + x_int = xtensor("x_int", dims=("a", "b"), shape=(2, 3), dtype="int32") + x_int_test = DataArray(np.arange(6, dtype="int32").reshape(2, 3), dims=("a", "b")) + + y5 = px.full_like(x_int, 2.5) + fn5 = xr_function([x_int], y5) + result5 = fn5(x_int_test) + expected5 = xr.full_like(x_int_test, 2.5) + xr_assert_allclose(result5, expected5, check_dtype=True) + + # Symbolic shapes + x_sym = xtensor("x_sym", dims=("a", "b"), shape=(None, 3)) + x_sym_test = DataArray( + np.arange(6, dtype=x_sym.type.dtype).reshape(2, 3), dims=("a", "b") + ) + + y6 = px.full_like(x_sym, 7.0) + fn6 = xr_function([x_sym], y6) + result6 = fn6(x_sym_test) + expected6 = xr.full_like(x_sym_test, 7.0) + xr_assert_allclose(result6, expected6, check_dtype=True) + + # Boolean dtype + x_bool = xtensor("x_bool", dims=("a", "b"), shape=(2, 3), dtype="bool") + x_bool_test = DataArray( + np.array([[True, False, True], [False, True, False]]), dims=("a", "b") + ) + + y8 = px.full_like(x_bool, True) + fn8 = xr_function([x_bool], y8) + result8 = fn8(x_bool_test) + expected8 = xr.full_like(x_bool_test, True) + xr_assert_allclose(result8, expected8, check_dtype=True) + + # Complex dtype + x_complex = xtensor("x_complex", dims=("a", "b"), shape=(2, 3), dtype="complex64") + x_complex_test = DataArray( + np.arange(6, dtype="complex64").reshape(2, 3), dims=("a", "b") + ) + + y9 = px.full_like(x_complex, 1 + 2j) + fn9 = xr_function([x_complex], y9) + result9 = fn9(x_complex_test) + expected9 = xr.full_like(x_complex_test, 1 + 2j) + xr_assert_allclose(result9, expected9, check_dtype=True) + + # Symbolic fill value + x_sym_fill = xtensor("x_sym_fill", dims=("a", "b"), shape=(2, 3), dtype="float64") + fill_val = xtensor("fill_val", dims=(), shape=(), dtype="float64") + x_sym_fill_test = xr_arange_like(x_sym_fill) + fill_val_test = DataArray(3.14, dims=()) + + y10 = px.full_like(x_sym_fill, fill_val) + fn10 = xr_function([x_sym_fill, fill_val], y10) + result10 = fn10(x_sym_fill_test, fill_val_test) + expected10 = xr.full_like(x_sym_fill_test, 3.14) + xr_assert_allclose(result10, expected10, check_dtype=True) + + # Test dtype conversion to bool when neither input nor fill_value are bool + x_float = xtensor("x_float", dims=("a", "b"), shape=(2, 3), dtype="float64") + x_float_test = xr_arange_like(x_float) + + y11 = px.full_like(x_float, 5.0, dtype="bool") + fn11 = xr_function([x_float], y11) + result11 = fn11(x_float_test) + expected11 = xr.full_like(x_float_test, 5.0, dtype="bool") + xr_assert_allclose(result11, expected11, check_dtype=True) + + # Verify the result is actually boolean + assert result11.dtype == "bool" + assert expected11.dtype == "bool" + + +def test_full_like_errors(): + """Test full_like function errors.""" + x = xtensor("x", dims=("a", "b"), shape=(2, 3), dtype="float64") + x_test = xr_arange_like(x) + + with pytest.raises(ValueError, match="fill_value must be a scalar"): + px.full_like(x, x_test) + + +def test_ones_like(): + """Test ones_like function, comparing with xarray's ones_like.""" + x = xtensor("x", dims=("a", "b"), shape=(2, 3), dtype="float64") + x_test = xr_arange_like(x) + + y1 = px.ones_like(x) + fn1 = xr_function([x], y1) + result1 = fn1(x_test) + expected1 = xr.ones_like(x_test) + xr_assert_allclose(result1, expected1) + assert result1.dtype == expected1.dtype + + +def test_zeros_like(): + """Test zeros_like function, comparing with xarray's zeros_like.""" + x = xtensor("x", dims=("a", "b"), shape=(2, 3), dtype="float64") + x_test = xr_arange_like(x) + + y1 = px.zeros_like(x) + fn1 = xr_function([x], y1) + result1 = fn1(x_test) + expected1 = xr.zeros_like(x_test) + xr_assert_allclose(result1, expected1) + assert result1.dtype == expected1.dtype diff --git a/tests/xtensor/util.py b/tests/xtensor/util.py index 81dc98a75c..1d76afe0ea 100644 --- a/tests/xtensor/util.py +++ b/tests/xtensor/util.py @@ -37,11 +37,30 @@ def xfn(*xr_inputs): return xfn -def xr_assert_allclose(x, y, *args, **kwargs): - # Assert that two xarray DataArrays are close, ignoring coordinates +def xr_assert_allclose(x, y, check_dtype=False, *args, **kwargs): + """Assert that two xarray DataArrays are close, ignoring coordinates. + + Mostly a wrapper around xarray.testing.assert_allclose, + but with the option to check the dtype. + + Parameters + ---------- + x : xarray.DataArray + The first xarray DataArray to compare. + y : xarray.DataArray + The second xarray DataArray to compare. + check_dtype : bool, optional + If True, check that the dtype of the two DataArrays is the same. + *args : + Additional arguments to pass to xarray.testing.assert_allclose. + **kwargs : + Additional keyword arguments to pass to xarray.testing.assert_allclose. + """ x = x.drop_vars(x.coords) y = y.drop_vars(y.coords) assert_allclose(x, y, *args, **kwargs) + if check_dtype: + assert x.dtype == y.dtype def xr_arange_like(x):