Skip to content

Commit 91465b8

Browse files
committed
Adding tests
1 parent f2cf8a4 commit 91465b8

File tree

4 files changed

+97
-31
lines changed

4 files changed

+97
-31
lines changed

pytensor/xtensor/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
import pytensor.xtensor.rewriting
44
from pytensor.xtensor import linalg
5-
from pytensor.xtensor.math import dot
5+
from pytensor.xtensor.math import dot, full_like, ones_like, zeros_like
66
from pytensor.xtensor.shape import concat
77
from pytensor.xtensor.type import (
88
as_xtensor,

pytensor/xtensor/math.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -279,6 +279,13 @@ def full_like(x, fill_value, dtype=None):
279279
(2, 3)
280280
"""
281281
x = as_xtensor(x)
282+
fill_value = as_xtensor(fill_value)
283+
284+
# Check that fill_value is a scalar (ndim=0)
285+
if fill_value.type.ndim != 0:
286+
raise ValueError(
287+
f"fill_value must be a scalar, got ndim={fill_value.type.ndim}"
288+
)
282289

283290
# Handle dtype conversion
284291
if dtype is not None:

tests/xtensor/test_math.py

Lines changed: 68 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from xarray import DataArray
1212

1313
import pytensor.scalar as ps
14+
import pytensor.xtensor as px
1415
import pytensor.xtensor.math as pxm
1516
from pytensor import function
1617
from pytensor.scalar import ScalarOp
@@ -324,100 +325,139 @@ def test_full_like():
324325
x = xtensor("x", dims=("a", "b"), shape=(2, 3), dtype="float64")
325326
x_test = xr_arange_like(x)
326327

327-
y1 = pxm.full_like(x, 5.0)
328+
y1 = px.full_like(x, 5.0)
328329
fn1 = xr_function([x], y1)
329330
result1 = fn1(x_test)
330331
expected1 = xr.full_like(x_test, 5.0)
331-
xr_assert_allclose(result1, expected1)
332+
xr_assert_allclose(result1, expected1, check_dtype=True)
333+
334+
# Other dtypes
335+
x_3d = xtensor("x_3d", dims=("a", "b", "c"), shape=(2, 3, 4), dtype="float32")
336+
x_3d_test = xr_arange_like(x_3d)
332337

333-
# Different dtypes
334-
y3 = pxm.full_like(x, 5.0, dtype="int32")
338+
y7 = px.full_like(x_3d, -1.0)
339+
fn7 = xr_function([x_3d], y7)
340+
result7 = fn7(x_3d_test)
341+
expected7 = xr.full_like(x_3d_test, -1.0)
342+
xr_assert_allclose(result7, expected7, check_dtype=True)
343+
344+
# Integer dtype
345+
y3 = px.full_like(x, 5.0, dtype="int32")
335346
fn3 = xr_function([x], y3)
336347
result3 = fn3(x_test)
337348
expected3 = xr.full_like(x_test, 5.0, dtype="int32")
338-
xr_assert_allclose(result3, expected3)
349+
xr_assert_allclose(result3, expected3, check_dtype=True)
339350

340351
# Different fill_value types
341-
y4 = pxm.full_like(x, np.array(3.14))
352+
y4 = px.full_like(x, np.array(3.14))
342353
fn4 = xr_function([x], y4)
343354
result4 = fn4(x_test)
344355
expected4 = xr.full_like(x_test, 3.14)
345-
xr_assert_allclose(result4, expected4)
356+
xr_assert_allclose(result4, expected4, check_dtype=True)
346357

347358
# Integer input with float fill_value
348359
x_int = xtensor("x_int", dims=("a", "b"), shape=(2, 3), dtype="int32")
349360
x_int_test = DataArray(np.arange(6, dtype="int32").reshape(2, 3), dims=("a", "b"))
350361

351-
y5 = pxm.full_like(x_int, 2.5)
362+
y5 = px.full_like(x_int, 2.5)
352363
fn5 = xr_function([x_int], y5)
353364
result5 = fn5(x_int_test)
354365
expected5 = xr.full_like(x_int_test, 2.5)
355-
xr_assert_allclose(result5, expected5)
366+
xr_assert_allclose(result5, expected5, check_dtype=True)
356367

357368
# Symbolic shapes
358369
x_sym = xtensor("x_sym", dims=("a", "b"), shape=(None, 3))
359-
x_sym_test = DataArray(np.arange(6).reshape(2, 3), dims=("a", "b"))
370+
x_sym_test = DataArray(
371+
np.arange(6, dtype=x_sym.type.dtype).reshape(2, 3), dims=("a", "b")
372+
)
360373

361-
y6 = pxm.full_like(x_sym, 7.0)
374+
y6 = px.full_like(x_sym, 7.0)
362375
fn6 = xr_function([x_sym], y6)
363376
result6 = fn6(x_sym_test)
364377
expected6 = xr.full_like(x_sym_test, 7.0)
365-
xr_assert_allclose(result6, expected6)
366-
367-
# Higher dimensional tensor
368-
x_3d = xtensor("x_3d", dims=("a", "b", "c"), shape=(2, 3, 4), dtype="float32")
369-
x_3d_test = xr_arange_like(x_3d)
370-
371-
y7 = pxm.full_like(x_3d, -1.0)
372-
fn7 = xr_function([x_3d], y7)
373-
result7 = fn7(x_3d_test)
374-
expected7 = xr.full_like(x_3d_test, -1.0)
375-
xr_assert_allclose(result7, expected7)
378+
xr_assert_allclose(result6, expected6, check_dtype=True)
376379

377380
# Boolean dtype
378381
x_bool = xtensor("x_bool", dims=("a", "b"), shape=(2, 3), dtype="bool")
379382
x_bool_test = DataArray(
380383
np.array([[True, False, True], [False, True, False]]), dims=("a", "b")
381384
)
382385

383-
y8 = pxm.full_like(x_bool, True)
386+
y8 = px.full_like(x_bool, True)
384387
fn8 = xr_function([x_bool], y8)
385388
result8 = fn8(x_bool_test)
386389
expected8 = xr.full_like(x_bool_test, True)
387-
xr_assert_allclose(result8, expected8)
390+
xr_assert_allclose(result8, expected8, check_dtype=True)
388391

389392
# Complex dtype
390393
x_complex = xtensor("x_complex", dims=("a", "b"), shape=(2, 3), dtype="complex64")
391394
x_complex_test = DataArray(
392395
np.arange(6, dtype="complex64").reshape(2, 3), dims=("a", "b")
393396
)
394397

395-
y9 = pxm.full_like(x_complex, 1 + 2j)
398+
y9 = px.full_like(x_complex, 1 + 2j)
396399
fn9 = xr_function([x_complex], y9)
397400
result9 = fn9(x_complex_test)
398401
expected9 = xr.full_like(x_complex_test, 1 + 2j)
399-
xr_assert_allclose(result9, expected9)
402+
xr_assert_allclose(result9, expected9, check_dtype=True)
403+
404+
# Symbolic fill value
405+
x_sym_fill = xtensor("x_sym_fill", dims=("a", "b"), shape=(2, 3), dtype="float64")
406+
fill_val = xtensor("fill_val", dims=(), shape=(), dtype="float64")
407+
x_sym_fill_test = xr_arange_like(x_sym_fill)
408+
fill_val_test = DataArray(3.14, dims=())
409+
410+
y10 = px.full_like(x_sym_fill, fill_val)
411+
fn10 = xr_function([x_sym_fill, fill_val], y10)
412+
result10 = fn10(x_sym_fill_test, fill_val_test)
413+
expected10 = xr.full_like(x_sym_fill_test, 3.14)
414+
xr_assert_allclose(result10, expected10, check_dtype=True)
415+
416+
# Test dtype conversion to bool when neither input nor fill_value are bool
417+
x_float = xtensor("x_float", dims=("a", "b"), shape=(2, 3), dtype="float64")
418+
x_float_test = xr_arange_like(x_float)
419+
420+
y11 = px.full_like(x_float, 5.0, dtype="bool")
421+
fn11 = xr_function([x_float], y11)
422+
result11 = fn11(x_float_test)
423+
expected11 = xr.full_like(x_float_test, 5.0, dtype="bool")
424+
xr_assert_allclose(result11, expected11, check_dtype=True)
425+
426+
# Verify the result is actually boolean
427+
assert result11.dtype == "bool"
428+
assert expected11.dtype == "bool"
429+
430+
431+
def test_full_like_errors():
432+
"""Test full_like function errors."""
433+
x = xtensor("x", dims=("a", "b"), shape=(2, 3), dtype="float64")
434+
x_test = xr_arange_like(x)
435+
436+
with pytest.raises(ValueError, match="fill_value must be a scalar"):
437+
px.full_like(x, x_test)
400438

401439

402440
def test_ones_like():
403441
"""Test ones_like function, comparing with xarray's ones_like."""
404442
x = xtensor("x", dims=("a", "b"), shape=(2, 3), dtype="float64")
405443
x_test = xr_arange_like(x)
406444

407-
y1 = pxm.ones_like(x)
445+
y1 = px.ones_like(x)
408446
fn1 = xr_function([x], y1)
409447
result1 = fn1(x_test)
410448
expected1 = xr.ones_like(x_test)
411449
xr_assert_allclose(result1, expected1)
450+
assert result1.dtype == expected1.dtype
412451

413452

414453
def test_zeros_like():
415454
"""Test zeros_like function, comparing with xarray's zeros_like."""
416455
x = xtensor("x", dims=("a", "b"), shape=(2, 3), dtype="float64")
417456
x_test = xr_arange_like(x)
418457

419-
y1 = pxm.zeros_like(x)
458+
y1 = px.zeros_like(x)
420459
fn1 = xr_function([x], y1)
421460
result1 = fn1(x_test)
422461
expected1 = xr.zeros_like(x_test)
423462
xr_assert_allclose(result1, expected1)
463+
assert result1.dtype == expected1.dtype

tests/xtensor/util.py

Lines changed: 21 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,11 +37,30 @@ def xfn(*xr_inputs):
3737
return xfn
3838

3939

40-
def xr_assert_allclose(x, y, *args, **kwargs):
41-
# Assert that two xarray DataArrays are close, ignoring coordinates
40+
def xr_assert_allclose(x, y, check_dtype=False, *args, **kwargs):
41+
"""Assert that two xarray DataArrays are close, ignoring coordinates.
42+
43+
Mostly a wrapper around xarray.testing.assert_allclose,
44+
but with the option to check the dtype.
45+
46+
Parameters
47+
----------
48+
x : xarray.DataArray
49+
The first xarray DataArray to compare.
50+
y : xarray.DataArray
51+
The second xarray DataArray to compare.
52+
check_dtype : bool, optional
53+
If True, check that the dtype of the two DataArrays is the same.
54+
*args :
55+
Additional arguments to pass to xarray.testing.assert_allclose.
56+
**kwargs :
57+
Additional keyword arguments to pass to xarray.testing.assert_allclose.
58+
"""
4259
x = x.drop_vars(x.coords)
4360
y = y.drop_vars(y.coords)
4461
assert_allclose(x, y, *args, **kwargs)
62+
if check_dtype:
63+
assert x.dtype == y.dtype
4564

4665

4766
def xr_arange_like(x):

0 commit comments

Comments
 (0)