Skip to content

Commit f2cf8a4

Browse files
committed
First pass at full_like, ones_like, and zeros_like
1 parent 41d9be4 commit f2cf8a4

File tree

2 files changed

+200
-0
lines changed

2 files changed

+200
-0
lines changed

pytensor/xtensor/math.py

Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -250,3 +250,96 @@ def dot(x, y, dim: str | Iterable[str] | EllipsisType | None = None):
250250
result = XDot(dims=tuple(dim_set))(x, y)
251251

252252
return result
253+
254+
255+
def full_like(x, fill_value, dtype=None):
256+
"""Create a new XTensorVariable with the same shape and dimensions, filled with a specified value.
257+
258+
Parameters
259+
----------
260+
x : XTensorVariable
261+
The tensor to fill.
262+
fill_value : scalar or XTensorVariable
263+
The value to fill the new tensor with.
264+
dtype : str or np.dtype, optional
265+
The data type of the new tensor. If None, uses the dtype of the input tensor.
266+
267+
Returns
268+
-------
269+
XTensorVariable
270+
A new tensor with the same shape and dimensions as self, filled with fill_value.
271+
272+
Examples
273+
--------
274+
>>> x = xtensor(dtype="float64", dims=("a", "b"), shape=(2, 3))
275+
>>> y = full_like(x, 5.0)
276+
>>> y.dims
277+
('a', 'b')
278+
>>> y.shape
279+
(2, 3)
280+
"""
281+
x = as_xtensor(x)
282+
283+
# Handle dtype conversion
284+
if dtype is not None:
285+
# If dtype is specified, cast the fill_value to that dtype
286+
fill_value = cast(fill_value, dtype)
287+
else:
288+
# If dtype is None, cast the fill_value to the input tensor's dtype
289+
# This matches xarray's behavior where it preserves the original dtype
290+
fill_value = cast(fill_value, x.type.dtype)
291+
292+
# Use the xtensor second function
293+
return second(x, fill_value)
294+
295+
296+
def ones_like(x, dtype=None):
297+
"""Create a new XTensorVariable with the same shape and dimensions, filled with ones.
298+
299+
Parameters
300+
----------
301+
x : XTensorVariable
302+
The tensor to fill.
303+
dtype : str or np.dtype, optional
304+
The data type of the new tensor. If None, uses the dtype of the input tensor.
305+
306+
Returns:
307+
XTensorVariable
308+
A new tensor with the same shape and dimensions as self, filled with ones.
309+
310+
Examples
311+
--------
312+
>>> x = xtensor(dtype="float64", dims=("a", "b"), shape=(2, 3))
313+
>>> y = ones_like(x)
314+
>>> y.dims
315+
('a', 'b')
316+
>>> y.shape
317+
(2, 3)
318+
"""
319+
return full_like(x, 1.0, dtype=dtype)
320+
321+
322+
def zeros_like(x, dtype=None):
323+
"""Create a new XTensorVariable with the same shape and dimensions, filled with zeros.
324+
325+
Parameters
326+
----------
327+
x : XTensorVariable
328+
The tensor to fill.
329+
dtype : str or np.dtype, optional
330+
The data type of the new tensor. If None, uses the dtype of the input tensor.
331+
332+
Returns:
333+
XTensorVariable
334+
A new tensor with the same shape and dimensions as self, filled with zeros.
335+
336+
Examples
337+
--------
338+
>>> x = xtensor(dtype="float64", dims=("a", "b"), shape=(2, 3))
339+
>>> y = zeros_like(x)
340+
>>> y.dims
341+
('a', 'b')
342+
>>> y.shape
343+
(2, 3)
344+
"""
345+
return full_like(x, 0.0, dtype=dtype)

tests/xtensor/test_math.py

Lines changed: 107 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import inspect
88

99
import numpy as np
10+
import xarray as xr
1011
from xarray import DataArray
1112

1213
import pytensor.scalar as ps
@@ -314,3 +315,109 @@ def test_dot_errors():
314315
# Doesn't fail until the rewrite
315316
with pytest.raises(ValueError, match="not aligned"):
316317
fn(x_test, y_test)
318+
319+
320+
def test_full_like():
321+
"""Test full_like function, comparing with xarray's full_like."""
322+
323+
# Basic functionality with scalar fill_value
324+
x = xtensor("x", dims=("a", "b"), shape=(2, 3), dtype="float64")
325+
x_test = xr_arange_like(x)
326+
327+
y1 = pxm.full_like(x, 5.0)
328+
fn1 = xr_function([x], y1)
329+
result1 = fn1(x_test)
330+
expected1 = xr.full_like(x_test, 5.0)
331+
xr_assert_allclose(result1, expected1)
332+
333+
# Different dtypes
334+
y3 = pxm.full_like(x, 5.0, dtype="int32")
335+
fn3 = xr_function([x], y3)
336+
result3 = fn3(x_test)
337+
expected3 = xr.full_like(x_test, 5.0, dtype="int32")
338+
xr_assert_allclose(result3, expected3)
339+
340+
# Different fill_value types
341+
y4 = pxm.full_like(x, np.array(3.14))
342+
fn4 = xr_function([x], y4)
343+
result4 = fn4(x_test)
344+
expected4 = xr.full_like(x_test, 3.14)
345+
xr_assert_allclose(result4, expected4)
346+
347+
# Integer input with float fill_value
348+
x_int = xtensor("x_int", dims=("a", "b"), shape=(2, 3), dtype="int32")
349+
x_int_test = DataArray(np.arange(6, dtype="int32").reshape(2, 3), dims=("a", "b"))
350+
351+
y5 = pxm.full_like(x_int, 2.5)
352+
fn5 = xr_function([x_int], y5)
353+
result5 = fn5(x_int_test)
354+
expected5 = xr.full_like(x_int_test, 2.5)
355+
xr_assert_allclose(result5, expected5)
356+
357+
# Symbolic shapes
358+
x_sym = xtensor("x_sym", dims=("a", "b"), shape=(None, 3))
359+
x_sym_test = DataArray(np.arange(6).reshape(2, 3), dims=("a", "b"))
360+
361+
y6 = pxm.full_like(x_sym, 7.0)
362+
fn6 = xr_function([x_sym], y6)
363+
result6 = fn6(x_sym_test)
364+
expected6 = xr.full_like(x_sym_test, 7.0)
365+
xr_assert_allclose(result6, expected6)
366+
367+
# Higher dimensional tensor
368+
x_3d = xtensor("x_3d", dims=("a", "b", "c"), shape=(2, 3, 4), dtype="float32")
369+
x_3d_test = xr_arange_like(x_3d)
370+
371+
y7 = pxm.full_like(x_3d, -1.0)
372+
fn7 = xr_function([x_3d], y7)
373+
result7 = fn7(x_3d_test)
374+
expected7 = xr.full_like(x_3d_test, -1.0)
375+
xr_assert_allclose(result7, expected7)
376+
377+
# Boolean dtype
378+
x_bool = xtensor("x_bool", dims=("a", "b"), shape=(2, 3), dtype="bool")
379+
x_bool_test = DataArray(
380+
np.array([[True, False, True], [False, True, False]]), dims=("a", "b")
381+
)
382+
383+
y8 = pxm.full_like(x_bool, True)
384+
fn8 = xr_function([x_bool], y8)
385+
result8 = fn8(x_bool_test)
386+
expected8 = xr.full_like(x_bool_test, True)
387+
xr_assert_allclose(result8, expected8)
388+
389+
# Complex dtype
390+
x_complex = xtensor("x_complex", dims=("a", "b"), shape=(2, 3), dtype="complex64")
391+
x_complex_test = DataArray(
392+
np.arange(6, dtype="complex64").reshape(2, 3), dims=("a", "b")
393+
)
394+
395+
y9 = pxm.full_like(x_complex, 1 + 2j)
396+
fn9 = xr_function([x_complex], y9)
397+
result9 = fn9(x_complex_test)
398+
expected9 = xr.full_like(x_complex_test, 1 + 2j)
399+
xr_assert_allclose(result9, expected9)
400+
401+
402+
def test_ones_like():
403+
"""Test ones_like function, comparing with xarray's ones_like."""
404+
x = xtensor("x", dims=("a", "b"), shape=(2, 3), dtype="float64")
405+
x_test = xr_arange_like(x)
406+
407+
y1 = pxm.ones_like(x)
408+
fn1 = xr_function([x], y1)
409+
result1 = fn1(x_test)
410+
expected1 = xr.ones_like(x_test)
411+
xr_assert_allclose(result1, expected1)
412+
413+
414+
def test_zeros_like():
415+
"""Test zeros_like function, comparing with xarray's zeros_like."""
416+
x = xtensor("x", dims=("a", "b"), shape=(2, 3), dtype="float64")
417+
x_test = xr_arange_like(x)
418+
419+
y1 = pxm.zeros_like(x)
420+
fn1 = xr_function([x], y1)
421+
result1 = fn1(x_test)
422+
expected1 = xr.zeros_like(x_test)
423+
xr_assert_allclose(result1, expected1)

0 commit comments

Comments
 (0)