|
7 | 7 | import inspect |
8 | 8 |
|
9 | 9 | import numpy as np |
| 10 | +import xarray as xr |
10 | 11 | from xarray import DataArray |
11 | 12 |
|
12 | 13 | import pytensor.scalar as ps |
@@ -314,3 +315,109 @@ def test_dot_errors(): |
314 | 315 | # Doesn't fail until the rewrite |
315 | 316 | with pytest.raises(ValueError, match="not aligned"): |
316 | 317 | fn(x_test, y_test) |
| 318 | + |
| 319 | + |
| 320 | +def test_full_like(): |
| 321 | + """Test full_like function, comparing with xarray's full_like.""" |
| 322 | + |
| 323 | + # Basic functionality with scalar fill_value |
| 324 | + x = xtensor("x", dims=("a", "b"), shape=(2, 3), dtype="float64") |
| 325 | + x_test = xr_arange_like(x) |
| 326 | + |
| 327 | + y1 = pxm.full_like(x, 5.0) |
| 328 | + fn1 = xr_function([x], y1) |
| 329 | + result1 = fn1(x_test) |
| 330 | + expected1 = xr.full_like(x_test, 5.0) |
| 331 | + xr_assert_allclose(result1, expected1) |
| 332 | + |
| 333 | + # Different dtypes |
| 334 | + y3 = pxm.full_like(x, 5.0, dtype="int32") |
| 335 | + fn3 = xr_function([x], y3) |
| 336 | + result3 = fn3(x_test) |
| 337 | + expected3 = xr.full_like(x_test, 5.0, dtype="int32") |
| 338 | + xr_assert_allclose(result3, expected3) |
| 339 | + |
| 340 | + # Different fill_value types |
| 341 | + y4 = pxm.full_like(x, np.array(3.14)) |
| 342 | + fn4 = xr_function([x], y4) |
| 343 | + result4 = fn4(x_test) |
| 344 | + expected4 = xr.full_like(x_test, 3.14) |
| 345 | + xr_assert_allclose(result4, expected4) |
| 346 | + |
| 347 | + # Integer input with float fill_value |
| 348 | + x_int = xtensor("x_int", dims=("a", "b"), shape=(2, 3), dtype="int32") |
| 349 | + x_int_test = DataArray(np.arange(6, dtype="int32").reshape(2, 3), dims=("a", "b")) |
| 350 | + |
| 351 | + y5 = pxm.full_like(x_int, 2.5) |
| 352 | + fn5 = xr_function([x_int], y5) |
| 353 | + result5 = fn5(x_int_test) |
| 354 | + expected5 = xr.full_like(x_int_test, 2.5) |
| 355 | + xr_assert_allclose(result5, expected5) |
| 356 | + |
| 357 | + # Symbolic shapes |
| 358 | + x_sym = xtensor("x_sym", dims=("a", "b"), shape=(None, 3)) |
| 359 | + x_sym_test = DataArray(np.arange(6).reshape(2, 3), dims=("a", "b")) |
| 360 | + |
| 361 | + y6 = pxm.full_like(x_sym, 7.0) |
| 362 | + fn6 = xr_function([x_sym], y6) |
| 363 | + result6 = fn6(x_sym_test) |
| 364 | + expected6 = xr.full_like(x_sym_test, 7.0) |
| 365 | + xr_assert_allclose(result6, expected6) |
| 366 | + |
| 367 | + # Higher dimensional tensor |
| 368 | + x_3d = xtensor("x_3d", dims=("a", "b", "c"), shape=(2, 3, 4), dtype="float32") |
| 369 | + x_3d_test = xr_arange_like(x_3d) |
| 370 | + |
| 371 | + y7 = pxm.full_like(x_3d, -1.0) |
| 372 | + fn7 = xr_function([x_3d], y7) |
| 373 | + result7 = fn7(x_3d_test) |
| 374 | + expected7 = xr.full_like(x_3d_test, -1.0) |
| 375 | + xr_assert_allclose(result7, expected7) |
| 376 | + |
| 377 | + # Boolean dtype |
| 378 | + x_bool = xtensor("x_bool", dims=("a", "b"), shape=(2, 3), dtype="bool") |
| 379 | + x_bool_test = DataArray( |
| 380 | + np.array([[True, False, True], [False, True, False]]), dims=("a", "b") |
| 381 | + ) |
| 382 | + |
| 383 | + y8 = pxm.full_like(x_bool, True) |
| 384 | + fn8 = xr_function([x_bool], y8) |
| 385 | + result8 = fn8(x_bool_test) |
| 386 | + expected8 = xr.full_like(x_bool_test, True) |
| 387 | + xr_assert_allclose(result8, expected8) |
| 388 | + |
| 389 | + # Complex dtype |
| 390 | + x_complex = xtensor("x_complex", dims=("a", "b"), shape=(2, 3), dtype="complex64") |
| 391 | + x_complex_test = DataArray( |
| 392 | + np.arange(6, dtype="complex64").reshape(2, 3), dims=("a", "b") |
| 393 | + ) |
| 394 | + |
| 395 | + y9 = pxm.full_like(x_complex, 1 + 2j) |
| 396 | + fn9 = xr_function([x_complex], y9) |
| 397 | + result9 = fn9(x_complex_test) |
| 398 | + expected9 = xr.full_like(x_complex_test, 1 + 2j) |
| 399 | + xr_assert_allclose(result9, expected9) |
| 400 | + |
| 401 | + |
| 402 | +def test_ones_like(): |
| 403 | + """Test ones_like function, comparing with xarray's ones_like.""" |
| 404 | + x = xtensor("x", dims=("a", "b"), shape=(2, 3), dtype="float64") |
| 405 | + x_test = xr_arange_like(x) |
| 406 | + |
| 407 | + y1 = pxm.ones_like(x) |
| 408 | + fn1 = xr_function([x], y1) |
| 409 | + result1 = fn1(x_test) |
| 410 | + expected1 = xr.ones_like(x_test) |
| 411 | + xr_assert_allclose(result1, expected1) |
| 412 | + |
| 413 | + |
| 414 | +def test_zeros_like(): |
| 415 | + """Test zeros_like function, comparing with xarray's zeros_like.""" |
| 416 | + x = xtensor("x", dims=("a", "b"), shape=(2, 3), dtype="float64") |
| 417 | + x_test = xr_arange_like(x) |
| 418 | + |
| 419 | + y1 = pxm.zeros_like(x) |
| 420 | + fn1 = xr_function([x], y1) |
| 421 | + result1 = fn1(x_test) |
| 422 | + expected1 = xr.zeros_like(x_test) |
| 423 | + xr_assert_allclose(result1, expected1) |
0 commit comments