Skip to content

Commit ae38884

Browse files
AllenDowneyricardoV94
authored andcommitted
Add full_like, ones_like, and zeros_like for XTensorVariable (#1514)
1 parent 815671d commit ae38884

File tree

4 files changed

+271
-3
lines changed

4 files changed

+271
-3
lines changed

pytensor/xtensor/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import pytensor.xtensor.rewriting
44
from pytensor.xtensor import linalg, random
55
from pytensor.xtensor.math import dot
6-
from pytensor.xtensor.shape import broadcast, concat
6+
from pytensor.xtensor.shape import broadcast, concat, full_like, ones_like, zeros_like
77
from pytensor.xtensor.type import (
88
as_xtensor,
99
xtensor,

pytensor/xtensor/shape.py

Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
from pytensor.tensor.type import integer_dtypes
1414
from pytensor.tensor.utils import get_static_shape_from_size_variables
1515
from pytensor.xtensor.basic import XOp
16+
from pytensor.xtensor.math import cast, second
1617
from pytensor.xtensor.type import XTensorVariable, as_xtensor, xtensor
1718
from pytensor.xtensor.vectorization import combine_dims_and_shape
1819

@@ -565,3 +566,100 @@ def broadcast(
565566
raise TypeError(f"exclude must be None, str, or Sequence, got {type(exclude)}")
566567
# xarray broadcast always returns a tuple, even if there's only one tensor
567568
return tuple(Broadcast(exclude=exclude)(*args, return_list=True)) # type: ignore
569+
570+
571+
def full_like(x, fill_value, dtype=None):
572+
"""Create a new XTensorVariable with the same shape and dimensions, filled with a specified value.
573+
574+
Parameters
575+
----------
576+
x : XTensorVariable
577+
The tensor to fill.
578+
fill_value : scalar or XTensorVariable
579+
The value to fill the new tensor with.
580+
dtype : str or np.dtype, optional
581+
The data type of the new tensor. If None, uses the dtype of the input tensor.
582+
583+
Returns
584+
-------
585+
XTensorVariable
586+
A new tensor with the same shape and dimensions as self, filled with fill_value.
587+
588+
Examples
589+
--------
590+
>>> from pytensor.xtensor import xtensor, full_like
591+
>>> x = xtensor(dtype="float64", dims=("a", "b"), shape=(2, 3))
592+
>>> y = full_like(x, 5.0)
593+
>>> assert y.dims == ("a", "b")
594+
>>> assert y.type.shape == (2, 3)
595+
"""
596+
x = as_xtensor(x)
597+
fill_value = as_xtensor(fill_value)
598+
599+
# Check that fill_value is a scalar (ndim=0)
600+
if fill_value.type.ndim != 0:
601+
raise ValueError(
602+
f"fill_value must be a scalar, got ndim={fill_value.type.ndim}"
603+
)
604+
605+
# Handle dtype conversion
606+
if dtype is not None:
607+
# If dtype is specified, cast the fill_value to that dtype
608+
fill_value = cast(fill_value, dtype)
609+
else:
610+
# If dtype is None, cast the fill_value to the input tensor's dtype
611+
# This matches xarray's behavior where it preserves the original dtype
612+
fill_value = cast(fill_value, x.type.dtype)
613+
614+
# Use the xtensor second function
615+
return second(x, fill_value)
616+
617+
618+
def ones_like(x, dtype=None):
619+
"""Create a new XTensorVariable with the same shape and dimensions, filled with ones.
620+
621+
Parameters
622+
----------
623+
x : XTensorVariable
624+
The tensor to fill.
625+
dtype : str or np.dtype, optional
626+
The data type of the new tensor. If None, uses the dtype of the input tensor.
627+
628+
Returns:
629+
XTensorVariable
630+
A new tensor with the same shape and dimensions as self, filled with ones.
631+
632+
Examples
633+
--------
634+
>>> from pytensor.xtensor import xtensor, full_like
635+
>>> x = xtensor(dtype="float64", dims=("a", "b"), shape=(2, 3))
636+
>>> y = ones_like(x)
637+
>>> assert y.dims == ("a", "b")
638+
>>> assert y.type.shape == (2, 3)
639+
"""
640+
return full_like(x, 1.0, dtype=dtype)
641+
642+
643+
def zeros_like(x, dtype=None):
644+
"""Create a new XTensorVariable with the same shape and dimensions, filled with zeros.
645+
646+
Parameters
647+
----------
648+
x : XTensorVariable
649+
The tensor to fill.
650+
dtype : str or np.dtype, optional
651+
The data type of the new tensor. If None, uses the dtype of the input tensor.
652+
653+
Returns:
654+
XTensorVariable
655+
A new tensor with the same shape and dimensions as self, filled with zeros.
656+
657+
Examples
658+
--------
659+
>>> from pytensor.xtensor import xtensor, full_like
660+
>>> x = xtensor(dtype="float64", dims=("a", "b"), shape=(2, 3))
661+
>>> y = zeros_like(x)
662+
>>> assert y.dims == ("a", "b")
663+
>>> assert y.type.shape == (2, 3)
664+
"""
665+
return full_like(x, 0.0, dtype=dtype)

tests/xtensor/test_shape.py

Lines changed: 151 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,13 +11,19 @@
1111
from xarray import DataArray
1212
from xarray import broadcast as xr_broadcast
1313
from xarray import concat as xr_concat
14+
from xarray import full_like as xr_full_like
15+
from xarray import ones_like as xr_ones_like
16+
from xarray import zeros_like as xr_zeros_like
1417

1518
from pytensor.tensor import scalar
1619
from pytensor.xtensor.shape import (
1720
broadcast,
1821
concat,
22+
full_like,
23+
ones_like,
1924
stack,
2025
unstack,
26+
zeros_like,
2127
)
2228
from pytensor.xtensor.type import xtensor
2329
from tests.xtensor.util import (
@@ -633,3 +639,148 @@ def test_broadcast_like(self, exclude):
633639
]
634640
for res, expected_res in zip(results, expected_results, strict=True):
635641
xr_assert_allclose(res, expected_res)
642+
643+
644+
def test_full_like():
645+
"""Test full_like function, comparing with xarray's full_like."""
646+
647+
# Basic functionality with scalar fill_value
648+
x = xtensor("x", dims=("a", "b"), shape=(2, 3), dtype="float64")
649+
x_test = xr_arange_like(x)
650+
651+
y1 = full_like(x, 5.0)
652+
fn1 = xr_function([x], y1)
653+
result1 = fn1(x_test)
654+
expected1 = xr_full_like(x_test, 5.0)
655+
xr_assert_allclose(result1, expected1, check_dtype=True)
656+
657+
# Other dtypes
658+
x_3d = xtensor("x_3d", dims=("a", "b", "c"), shape=(2, 3, 4), dtype="float32")
659+
x_3d_test = xr_arange_like(x_3d)
660+
661+
y7 = full_like(x_3d, -1.0)
662+
fn7 = xr_function([x_3d], y7)
663+
result7 = fn7(x_3d_test)
664+
expected7 = xr_full_like(x_3d_test, -1.0)
665+
xr_assert_allclose(result7, expected7, check_dtype=True)
666+
667+
# Integer dtype
668+
y3 = full_like(x, 5.0, dtype="int32")
669+
fn3 = xr_function([x], y3)
670+
result3 = fn3(x_test)
671+
expected3 = xr_full_like(x_test, 5.0, dtype="int32")
672+
xr_assert_allclose(result3, expected3, check_dtype=True)
673+
674+
# Different fill_value types
675+
y4 = full_like(x, np.array(3.14))
676+
fn4 = xr_function([x], y4)
677+
result4 = fn4(x_test)
678+
expected4 = xr_full_like(x_test, 3.14)
679+
xr_assert_allclose(result4, expected4, check_dtype=True)
680+
681+
# Integer input with float fill_value
682+
x_int = xtensor("x_int", dims=("a", "b"), shape=(2, 3), dtype="int32")
683+
x_int_test = DataArray(np.arange(6, dtype="int32").reshape(2, 3), dims=("a", "b"))
684+
685+
y5 = full_like(x_int, 2.5)
686+
fn5 = xr_function([x_int], y5)
687+
result5 = fn5(x_int_test)
688+
expected5 = xr_full_like(x_int_test, 2.5)
689+
xr_assert_allclose(result5, expected5, check_dtype=True)
690+
691+
# Symbolic shapes
692+
x_sym = xtensor("x_sym", dims=("a", "b"), shape=(None, 3))
693+
x_sym_test = DataArray(
694+
np.arange(6, dtype=x_sym.type.dtype).reshape(2, 3), dims=("a", "b")
695+
)
696+
697+
y6 = full_like(x_sym, 7.0)
698+
fn6 = xr_function([x_sym], y6)
699+
result6 = fn6(x_sym_test)
700+
expected6 = xr_full_like(x_sym_test, 7.0)
701+
xr_assert_allclose(result6, expected6, check_dtype=True)
702+
703+
# Boolean dtype
704+
x_bool = xtensor("x_bool", dims=("a", "b"), shape=(2, 3), dtype="bool")
705+
x_bool_test = DataArray(
706+
np.array([[True, False, True], [False, True, False]]), dims=("a", "b")
707+
)
708+
709+
y8 = full_like(x_bool, True)
710+
fn8 = xr_function([x_bool], y8)
711+
result8 = fn8(x_bool_test)
712+
expected8 = xr_full_like(x_bool_test, True)
713+
xr_assert_allclose(result8, expected8, check_dtype=True)
714+
715+
# Complex dtype
716+
x_complex = xtensor("x_complex", dims=("a", "b"), shape=(2, 3), dtype="complex64")
717+
x_complex_test = DataArray(
718+
np.arange(6, dtype="complex64").reshape(2, 3), dims=("a", "b")
719+
)
720+
721+
y9 = full_like(x_complex, 1 + 2j)
722+
fn9 = xr_function([x_complex], y9)
723+
result9 = fn9(x_complex_test)
724+
expected9 = xr_full_like(x_complex_test, 1 + 2j)
725+
xr_assert_allclose(result9, expected9, check_dtype=True)
726+
727+
# Symbolic fill value
728+
x_sym_fill = xtensor("x_sym_fill", dims=("a", "b"), shape=(2, 3), dtype="float64")
729+
fill_val = xtensor("fill_val", dims=(), shape=(), dtype="float64")
730+
x_sym_fill_test = xr_arange_like(x_sym_fill)
731+
fill_val_test = DataArray(3.14, dims=())
732+
733+
y10 = full_like(x_sym_fill, fill_val)
734+
fn10 = xr_function([x_sym_fill, fill_val], y10)
735+
result10 = fn10(x_sym_fill_test, fill_val_test)
736+
expected10 = xr_full_like(x_sym_fill_test, 3.14)
737+
xr_assert_allclose(result10, expected10, check_dtype=True)
738+
739+
# Test dtype conversion to bool when neither input nor fill_value are bool
740+
x_float = xtensor("x_float", dims=("a", "b"), shape=(2, 3), dtype="float64")
741+
x_float_test = xr_arange_like(x_float)
742+
743+
y11 = full_like(x_float, 5.0, dtype="bool")
744+
fn11 = xr_function([x_float], y11)
745+
result11 = fn11(x_float_test)
746+
expected11 = xr_full_like(x_float_test, 5.0, dtype="bool")
747+
xr_assert_allclose(result11, expected11, check_dtype=True)
748+
749+
# Verify the result is actually boolean
750+
assert result11.dtype == "bool"
751+
assert expected11.dtype == "bool"
752+
753+
754+
def test_full_like_errors():
755+
"""Test full_like function errors."""
756+
x = xtensor("x", dims=("a", "b"), shape=(2, 3), dtype="float64")
757+
x_test = xr_arange_like(x)
758+
759+
with pytest.raises(ValueError, match="fill_value must be a scalar"):
760+
full_like(x, x_test)
761+
762+
763+
def test_ones_like():
764+
"""Test ones_like function, comparing with xarray's ones_like."""
765+
x = xtensor("x", dims=("a", "b"), shape=(2, 3), dtype="float64")
766+
x_test = xr_arange_like(x)
767+
768+
y1 = ones_like(x)
769+
fn1 = xr_function([x], y1)
770+
result1 = fn1(x_test)
771+
expected1 = xr_ones_like(x_test)
772+
xr_assert_allclose(result1, expected1)
773+
assert result1.dtype == expected1.dtype
774+
775+
776+
def test_zeros_like():
777+
"""Test zeros_like function, comparing with xarray's zeros_like."""
778+
x = xtensor("x", dims=("a", "b"), shape=(2, 3), dtype="float64")
779+
x_test = xr_arange_like(x)
780+
781+
y1 = zeros_like(x)
782+
fn1 = xr_function([x], y1)
783+
result1 = fn1(x_test)
784+
expected1 = xr_zeros_like(x_test)
785+
xr_assert_allclose(result1, expected1)
786+
assert result1.dtype == expected1.dtype

tests/xtensor/util.py

Lines changed: 21 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,11 +37,30 @@ def xfn(*xr_inputs):
3737
return xfn
3838

3939

40-
def xr_assert_allclose(x, y, *args, **kwargs):
41-
# Assert that two xarray DataArrays are close, ignoring coordinates
40+
def xr_assert_allclose(x, y, check_dtype=False, *args, **kwargs):
41+
"""Assert that two xarray DataArrays are close, ignoring coordinates.
42+
43+
Mostly a wrapper around xarray.testing.assert_allclose,
44+
but with the option to check the dtype.
45+
46+
Parameters
47+
----------
48+
x : xarray.DataArray
49+
The first xarray DataArray to compare.
50+
y : xarray.DataArray
51+
The second xarray DataArray to compare.
52+
check_dtype : bool, optional
53+
If True, check that the dtype of the two DataArrays is the same.
54+
*args :
55+
Additional arguments to pass to xarray.testing.assert_allclose.
56+
**kwargs :
57+
Additional keyword arguments to pass to xarray.testing.assert_allclose.
58+
"""
4259
x = x.drop_vars(x.coords)
4360
y = y.drop_vars(y.coords)
4461
assert_allclose(x, y, *args, **kwargs)
62+
if check_dtype:
63+
assert x.dtype == y.dtype
4564

4665

4766
def xr_arange_like(x):

0 commit comments

Comments
 (0)