Skip to content

Commit 5296a80

Browse files
authored
Add full_like, ones_like, and zeros_like for XTensorVariable (#1514)
1 parent 41d9be4 commit 5296a80

File tree

4 files changed

+270
-7
lines changed

4 files changed

+270
-7
lines changed

pytensor/xtensor/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import pytensor.xtensor.rewriting
44
from pytensor.xtensor import linalg
55
from pytensor.xtensor.math import dot
6-
from pytensor.xtensor.shape import concat
6+
from pytensor.xtensor.shape import concat, full_like, ones_like, zeros_like
77
from pytensor.xtensor.type import (
88
as_xtensor,
99
xtensor,

pytensor/xtensor/shape.py

Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
from pytensor.tensor.type import integer_dtypes
1414
from pytensor.tensor.utils import get_static_shape_from_size_variables
1515
from pytensor.xtensor.basic import XOp
16+
from pytensor.xtensor.math import cast, second
1617
from pytensor.xtensor.type import as_xtensor, xtensor
1718

1819

@@ -498,3 +499,100 @@ def expand_dims(x, dim=None, create_index_for_new_dim=None, axis=None, **dim_kwa
498499
x = Transpose(dims=tuple(target_dims))(x)
499500

500501
return x
502+
503+
504+
def full_like(x, fill_value, dtype=None):
505+
"""Create a new XTensorVariable with the same shape and dimensions, filled with a specified value.
506+
507+
Parameters
508+
----------
509+
x : XTensorVariable
510+
The tensor to fill.
511+
fill_value : scalar or XTensorVariable
512+
The value to fill the new tensor with.
513+
dtype : str or np.dtype, optional
514+
The data type of the new tensor. If None, uses the dtype of the input tensor.
515+
516+
Returns
517+
-------
518+
XTensorVariable
519+
A new tensor with the same shape and dimensions as self, filled with fill_value.
520+
521+
Examples
522+
--------
523+
>>> from pytensor.xtensor import xtensor, full_like
524+
>>> x = xtensor(dtype="float64", dims=("a", "b"), shape=(2, 3))
525+
>>> y = full_like(x, 5.0)
526+
>>> assert y.dims == ("a", "b")
527+
>>> assert y.type.shape == (2, 3)
528+
"""
529+
x = as_xtensor(x)
530+
fill_value = as_xtensor(fill_value)
531+
532+
# Check that fill_value is a scalar (ndim=0)
533+
if fill_value.type.ndim != 0:
534+
raise ValueError(
535+
f"fill_value must be a scalar, got ndim={fill_value.type.ndim}"
536+
)
537+
538+
# Handle dtype conversion
539+
if dtype is not None:
540+
# If dtype is specified, cast the fill_value to that dtype
541+
fill_value = cast(fill_value, dtype)
542+
else:
543+
# If dtype is None, cast the fill_value to the input tensor's dtype
544+
# This matches xarray's behavior where it preserves the original dtype
545+
fill_value = cast(fill_value, x.type.dtype)
546+
547+
# Use the xtensor second function
548+
return second(x, fill_value)
549+
550+
551+
def ones_like(x, dtype=None):
552+
"""Create a new XTensorVariable with the same shape and dimensions, filled with ones.
553+
554+
Parameters
555+
----------
556+
x : XTensorVariable
557+
The tensor to fill.
558+
dtype : str or np.dtype, optional
559+
The data type of the new tensor. If None, uses the dtype of the input tensor.
560+
561+
Returns:
562+
XTensorVariable
563+
A new tensor with the same shape and dimensions as self, filled with ones.
564+
565+
Examples
566+
--------
567+
>>> from pytensor.xtensor import xtensor, full_like
568+
>>> x = xtensor(dtype="float64", dims=("a", "b"), shape=(2, 3))
569+
>>> y = ones_like(x)
570+
>>> assert y.dims == ("a", "b")
571+
>>> assert y.type.shape == (2, 3)
572+
"""
573+
return full_like(x, 1.0, dtype=dtype)
574+
575+
576+
def zeros_like(x, dtype=None):
577+
"""Create a new XTensorVariable with the same shape and dimensions, filled with zeros.
578+
579+
Parameters
580+
----------
581+
x : XTensorVariable
582+
The tensor to fill.
583+
dtype : str or np.dtype, optional
584+
The data type of the new tensor. If None, uses the dtype of the input tensor.
585+
586+
Returns:
587+
XTensorVariable
588+
A new tensor with the same shape and dimensions as self, filled with zeros.
589+
590+
Examples
591+
--------
592+
>>> from pytensor.xtensor import xtensor, full_like
593+
>>> x = xtensor(dtype="float64", dims=("a", "b"), shape=(2, 3))
594+
>>> y = zeros_like(x)
595+
>>> assert y.dims == ("a", "b")
596+
>>> assert y.type.shape == (2, 3)
597+
"""
598+
return full_like(x, 0.0, dtype=dtype)

tests/xtensor/test_shape.py

Lines changed: 150 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,10 @@
88
from itertools import chain, combinations
99

1010
import numpy as np
11+
import xarray as xr
1112
from xarray import DataArray
12-
from xarray import concat as xr_concat
1313

14+
import pytensor.xtensor as px
1415
from pytensor.tensor import scalar
1516
from pytensor.xtensor.shape import (
1617
concat,
@@ -226,7 +227,7 @@ def test_concat(dim):
226227
x3_test = xr_random_like(x3, rng)
227228

228229
res = fn(x1_test, x2_test, x3_test)
229-
expected_res = xr_concat([x1_test, x2_test, x3_test], dim=dim)
230+
expected_res = xr.concat([x1_test, x2_test, x3_test], dim=dim)
230231
xr_assert_allclose(res, expected_res)
231232

232233

@@ -248,7 +249,7 @@ def test_concat_with_broadcast(dim):
248249
x3_test = xr_random_like(x3, rng)
249250
x4_test = xr_random_like(x4, rng)
250251
res = fn(x1_test, x2_test, x3_test, x4_test)
251-
expected_res = xr_concat([x1_test, x2_test, x3_test, x4_test], dim=dim)
252+
expected_res = xr.concat([x1_test, x2_test, x3_test, x4_test], dim=dim)
252253
xr_assert_allclose(res, expected_res)
253254

254255

@@ -263,7 +264,7 @@ def test_concat_scalar():
263264
x1_test = xr_random_like(x1)
264265
x2_test = xr_random_like(x2)
265266
res = fn(x1_test, x2_test)
266-
expected_res = xr_concat([x1_test, x2_test], dim="new_dim")
267+
expected_res = xr.concat([x1_test, x2_test], dim="new_dim")
267268
xr_assert_allclose(res, expected_res)
268269

269270

@@ -466,3 +467,148 @@ def test_expand_dims_errors():
466467
# Test with a numpy array as dim (not supported)
467468
with pytest.raises(TypeError, match="unhashable type"):
468469
y.expand_dims(np.array([1, 2]))
470+
471+
472+
def test_full_like():
473+
"""Test full_like function, comparing with xarray's full_like."""
474+
475+
# Basic functionality with scalar fill_value
476+
x = xtensor("x", dims=("a", "b"), shape=(2, 3), dtype="float64")
477+
x_test = xr_arange_like(x)
478+
479+
y1 = px.full_like(x, 5.0)
480+
fn1 = xr_function([x], y1)
481+
result1 = fn1(x_test)
482+
expected1 = xr.full_like(x_test, 5.0)
483+
xr_assert_allclose(result1, expected1, check_dtype=True)
484+
485+
# Other dtypes
486+
x_3d = xtensor("x_3d", dims=("a", "b", "c"), shape=(2, 3, 4), dtype="float32")
487+
x_3d_test = xr_arange_like(x_3d)
488+
489+
y7 = px.full_like(x_3d, -1.0)
490+
fn7 = xr_function([x_3d], y7)
491+
result7 = fn7(x_3d_test)
492+
expected7 = xr.full_like(x_3d_test, -1.0)
493+
xr_assert_allclose(result7, expected7, check_dtype=True)
494+
495+
# Integer dtype
496+
y3 = px.full_like(x, 5.0, dtype="int32")
497+
fn3 = xr_function([x], y3)
498+
result3 = fn3(x_test)
499+
expected3 = xr.full_like(x_test, 5.0, dtype="int32")
500+
xr_assert_allclose(result3, expected3, check_dtype=True)
501+
502+
# Different fill_value types
503+
y4 = px.full_like(x, np.array(3.14))
504+
fn4 = xr_function([x], y4)
505+
result4 = fn4(x_test)
506+
expected4 = xr.full_like(x_test, 3.14)
507+
xr_assert_allclose(result4, expected4, check_dtype=True)
508+
509+
# Integer input with float fill_value
510+
x_int = xtensor("x_int", dims=("a", "b"), shape=(2, 3), dtype="int32")
511+
x_int_test = DataArray(np.arange(6, dtype="int32").reshape(2, 3), dims=("a", "b"))
512+
513+
y5 = px.full_like(x_int, 2.5)
514+
fn5 = xr_function([x_int], y5)
515+
result5 = fn5(x_int_test)
516+
expected5 = xr.full_like(x_int_test, 2.5)
517+
xr_assert_allclose(result5, expected5, check_dtype=True)
518+
519+
# Symbolic shapes
520+
x_sym = xtensor("x_sym", dims=("a", "b"), shape=(None, 3))
521+
x_sym_test = DataArray(
522+
np.arange(6, dtype=x_sym.type.dtype).reshape(2, 3), dims=("a", "b")
523+
)
524+
525+
y6 = px.full_like(x_sym, 7.0)
526+
fn6 = xr_function([x_sym], y6)
527+
result6 = fn6(x_sym_test)
528+
expected6 = xr.full_like(x_sym_test, 7.0)
529+
xr_assert_allclose(result6, expected6, check_dtype=True)
530+
531+
# Boolean dtype
532+
x_bool = xtensor("x_bool", dims=("a", "b"), shape=(2, 3), dtype="bool")
533+
x_bool_test = DataArray(
534+
np.array([[True, False, True], [False, True, False]]), dims=("a", "b")
535+
)
536+
537+
y8 = px.full_like(x_bool, True)
538+
fn8 = xr_function([x_bool], y8)
539+
result8 = fn8(x_bool_test)
540+
expected8 = xr.full_like(x_bool_test, True)
541+
xr_assert_allclose(result8, expected8, check_dtype=True)
542+
543+
# Complex dtype
544+
x_complex = xtensor("x_complex", dims=("a", "b"), shape=(2, 3), dtype="complex64")
545+
x_complex_test = DataArray(
546+
np.arange(6, dtype="complex64").reshape(2, 3), dims=("a", "b")
547+
)
548+
549+
y9 = px.full_like(x_complex, 1 + 2j)
550+
fn9 = xr_function([x_complex], y9)
551+
result9 = fn9(x_complex_test)
552+
expected9 = xr.full_like(x_complex_test, 1 + 2j)
553+
xr_assert_allclose(result9, expected9, check_dtype=True)
554+
555+
# Symbolic fill value
556+
x_sym_fill = xtensor("x_sym_fill", dims=("a", "b"), shape=(2, 3), dtype="float64")
557+
fill_val = xtensor("fill_val", dims=(), shape=(), dtype="float64")
558+
x_sym_fill_test = xr_arange_like(x_sym_fill)
559+
fill_val_test = DataArray(3.14, dims=())
560+
561+
y10 = px.full_like(x_sym_fill, fill_val)
562+
fn10 = xr_function([x_sym_fill, fill_val], y10)
563+
result10 = fn10(x_sym_fill_test, fill_val_test)
564+
expected10 = xr.full_like(x_sym_fill_test, 3.14)
565+
xr_assert_allclose(result10, expected10, check_dtype=True)
566+
567+
# Test dtype conversion to bool when neither input nor fill_value are bool
568+
x_float = xtensor("x_float", dims=("a", "b"), shape=(2, 3), dtype="float64")
569+
x_float_test = xr_arange_like(x_float)
570+
571+
y11 = px.full_like(x_float, 5.0, dtype="bool")
572+
fn11 = xr_function([x_float], y11)
573+
result11 = fn11(x_float_test)
574+
expected11 = xr.full_like(x_float_test, 5.0, dtype="bool")
575+
xr_assert_allclose(result11, expected11, check_dtype=True)
576+
577+
# Verify the result is actually boolean
578+
assert result11.dtype == "bool"
579+
assert expected11.dtype == "bool"
580+
581+
582+
def test_full_like_errors():
583+
"""Test full_like function errors."""
584+
x = xtensor("x", dims=("a", "b"), shape=(2, 3), dtype="float64")
585+
x_test = xr_arange_like(x)
586+
587+
with pytest.raises(ValueError, match="fill_value must be a scalar"):
588+
px.full_like(x, x_test)
589+
590+
591+
def test_ones_like():
592+
"""Test ones_like function, comparing with xarray's ones_like."""
593+
x = xtensor("x", dims=("a", "b"), shape=(2, 3), dtype="float64")
594+
x_test = xr_arange_like(x)
595+
596+
y1 = px.ones_like(x)
597+
fn1 = xr_function([x], y1)
598+
result1 = fn1(x_test)
599+
expected1 = xr.ones_like(x_test)
600+
xr_assert_allclose(result1, expected1)
601+
assert result1.dtype == expected1.dtype
602+
603+
604+
def test_zeros_like():
605+
"""Test zeros_like function, comparing with xarray's zeros_like."""
606+
x = xtensor("x", dims=("a", "b"), shape=(2, 3), dtype="float64")
607+
x_test = xr_arange_like(x)
608+
609+
y1 = px.zeros_like(x)
610+
fn1 = xr_function([x], y1)
611+
result1 = fn1(x_test)
612+
expected1 = xr.zeros_like(x_test)
613+
xr_assert_allclose(result1, expected1)
614+
assert result1.dtype == expected1.dtype

tests/xtensor/util.py

Lines changed: 21 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,11 +37,30 @@ def xfn(*xr_inputs):
3737
return xfn
3838

3939

40-
def xr_assert_allclose(x, y, *args, **kwargs):
41-
# Assert that two xarray DataArrays are close, ignoring coordinates
40+
def xr_assert_allclose(x, y, check_dtype=False, *args, **kwargs):
41+
"""Assert that two xarray DataArrays are close, ignoring coordinates.
42+
43+
Mostly a wrapper around xarray.testing.assert_allclose,
44+
but with the option to check the dtype.
45+
46+
Parameters
47+
----------
48+
x : xarray.DataArray
49+
The first xarray DataArray to compare.
50+
y : xarray.DataArray
51+
The second xarray DataArray to compare.
52+
check_dtype : bool, optional
53+
If True, check that the dtype of the two DataArrays is the same.
54+
*args :
55+
Additional arguments to pass to xarray.testing.assert_allclose.
56+
**kwargs :
57+
Additional keyword arguments to pass to xarray.testing.assert_allclose.
58+
"""
4259
x = x.drop_vars(x.coords)
4360
y = y.drop_vars(y.coords)
4461
assert_allclose(x, y, *args, **kwargs)
62+
if check_dtype:
63+
assert x.dtype == y.dtype
4564

4665

4766
def xr_arange_like(x):

0 commit comments

Comments
 (0)