|
11 | 11 | from xarray import DataArray
|
12 | 12 | from xarray import broadcast as xr_broadcast
|
13 | 13 | from xarray import concat as xr_concat
|
| 14 | +from xarray import full_like as xr_full_like |
| 15 | +from xarray import ones_like as xr_ones_like |
| 16 | +from xarray import zeros_like as xr_zeros_like |
14 | 17 |
|
15 | 18 | from pytensor.tensor import scalar
|
16 | 19 | from pytensor.xtensor.shape import (
|
17 | 20 | broadcast,
|
18 | 21 | concat,
|
| 22 | + full_like, |
| 23 | + ones_like, |
19 | 24 | stack,
|
20 | 25 | unstack,
|
| 26 | + zeros_like, |
21 | 27 | )
|
22 | 28 | from pytensor.xtensor.type import xtensor
|
23 | 29 | from tests.xtensor.util import (
|
@@ -633,3 +639,148 @@ def test_broadcast_like(self, exclude):
|
633 | 639 | ]
|
634 | 640 | for res, expected_res in zip(results, expected_results, strict=True):
|
635 | 641 | xr_assert_allclose(res, expected_res)
|
| 642 | + |
| 643 | + |
| 644 | +def test_full_like(): |
| 645 | + """Test full_like function, comparing with xarray's full_like.""" |
| 646 | + |
| 647 | + # Basic functionality with scalar fill_value |
| 648 | + x = xtensor("x", dims=("a", "b"), shape=(2, 3), dtype="float64") |
| 649 | + x_test = xr_arange_like(x) |
| 650 | + |
| 651 | + y1 = full_like(x, 5.0) |
| 652 | + fn1 = xr_function([x], y1) |
| 653 | + result1 = fn1(x_test) |
| 654 | + expected1 = xr_full_like(x_test, 5.0) |
| 655 | + xr_assert_allclose(result1, expected1, check_dtype=True) |
| 656 | + |
| 657 | + # Other dtypes |
| 658 | + x_3d = xtensor("x_3d", dims=("a", "b", "c"), shape=(2, 3, 4), dtype="float32") |
| 659 | + x_3d_test = xr_arange_like(x_3d) |
| 660 | + |
| 661 | + y7 = full_like(x_3d, -1.0) |
| 662 | + fn7 = xr_function([x_3d], y7) |
| 663 | + result7 = fn7(x_3d_test) |
| 664 | + expected7 = xr_full_like(x_3d_test, -1.0) |
| 665 | + xr_assert_allclose(result7, expected7, check_dtype=True) |
| 666 | + |
| 667 | + # Integer dtype |
| 668 | + y3 = full_like(x, 5.0, dtype="int32") |
| 669 | + fn3 = xr_function([x], y3) |
| 670 | + result3 = fn3(x_test) |
| 671 | + expected3 = xr_full_like(x_test, 5.0, dtype="int32") |
| 672 | + xr_assert_allclose(result3, expected3, check_dtype=True) |
| 673 | + |
| 674 | + # Different fill_value types |
| 675 | + y4 = full_like(x, np.array(3.14)) |
| 676 | + fn4 = xr_function([x], y4) |
| 677 | + result4 = fn4(x_test) |
| 678 | + expected4 = xr_full_like(x_test, 3.14) |
| 679 | + xr_assert_allclose(result4, expected4, check_dtype=True) |
| 680 | + |
| 681 | + # Integer input with float fill_value |
| 682 | + x_int = xtensor("x_int", dims=("a", "b"), shape=(2, 3), dtype="int32") |
| 683 | + x_int_test = DataArray(np.arange(6, dtype="int32").reshape(2, 3), dims=("a", "b")) |
| 684 | + |
| 685 | + y5 = full_like(x_int, 2.5) |
| 686 | + fn5 = xr_function([x_int], y5) |
| 687 | + result5 = fn5(x_int_test) |
| 688 | + expected5 = xr_full_like(x_int_test, 2.5) |
| 689 | + xr_assert_allclose(result5, expected5, check_dtype=True) |
| 690 | + |
| 691 | + # Symbolic shapes |
| 692 | + x_sym = xtensor("x_sym", dims=("a", "b"), shape=(None, 3)) |
| 693 | + x_sym_test = DataArray( |
| 694 | + np.arange(6, dtype=x_sym.type.dtype).reshape(2, 3), dims=("a", "b") |
| 695 | + ) |
| 696 | + |
| 697 | + y6 = full_like(x_sym, 7.0) |
| 698 | + fn6 = xr_function([x_sym], y6) |
| 699 | + result6 = fn6(x_sym_test) |
| 700 | + expected6 = xr_full_like(x_sym_test, 7.0) |
| 701 | + xr_assert_allclose(result6, expected6, check_dtype=True) |
| 702 | + |
| 703 | + # Boolean dtype |
| 704 | + x_bool = xtensor("x_bool", dims=("a", "b"), shape=(2, 3), dtype="bool") |
| 705 | + x_bool_test = DataArray( |
| 706 | + np.array([[True, False, True], [False, True, False]]), dims=("a", "b") |
| 707 | + ) |
| 708 | + |
| 709 | + y8 = full_like(x_bool, True) |
| 710 | + fn8 = xr_function([x_bool], y8) |
| 711 | + result8 = fn8(x_bool_test) |
| 712 | + expected8 = xr_full_like(x_bool_test, True) |
| 713 | + xr_assert_allclose(result8, expected8, check_dtype=True) |
| 714 | + |
| 715 | + # Complex dtype |
| 716 | + x_complex = xtensor("x_complex", dims=("a", "b"), shape=(2, 3), dtype="complex64") |
| 717 | + x_complex_test = DataArray( |
| 718 | + np.arange(6, dtype="complex64").reshape(2, 3), dims=("a", "b") |
| 719 | + ) |
| 720 | + |
| 721 | + y9 = full_like(x_complex, 1 + 2j) |
| 722 | + fn9 = xr_function([x_complex], y9) |
| 723 | + result9 = fn9(x_complex_test) |
| 724 | + expected9 = xr_full_like(x_complex_test, 1 + 2j) |
| 725 | + xr_assert_allclose(result9, expected9, check_dtype=True) |
| 726 | + |
| 727 | + # Symbolic fill value |
| 728 | + x_sym_fill = xtensor("x_sym_fill", dims=("a", "b"), shape=(2, 3), dtype="float64") |
| 729 | + fill_val = xtensor("fill_val", dims=(), shape=(), dtype="float64") |
| 730 | + x_sym_fill_test = xr_arange_like(x_sym_fill) |
| 731 | + fill_val_test = DataArray(3.14, dims=()) |
| 732 | + |
| 733 | + y10 = full_like(x_sym_fill, fill_val) |
| 734 | + fn10 = xr_function([x_sym_fill, fill_val], y10) |
| 735 | + result10 = fn10(x_sym_fill_test, fill_val_test) |
| 736 | + expected10 = xr_full_like(x_sym_fill_test, 3.14) |
| 737 | + xr_assert_allclose(result10, expected10, check_dtype=True) |
| 738 | + |
| 739 | + # Test dtype conversion to bool when neither input nor fill_value are bool |
| 740 | + x_float = xtensor("x_float", dims=("a", "b"), shape=(2, 3), dtype="float64") |
| 741 | + x_float_test = xr_arange_like(x_float) |
| 742 | + |
| 743 | + y11 = full_like(x_float, 5.0, dtype="bool") |
| 744 | + fn11 = xr_function([x_float], y11) |
| 745 | + result11 = fn11(x_float_test) |
| 746 | + expected11 = xr_full_like(x_float_test, 5.0, dtype="bool") |
| 747 | + xr_assert_allclose(result11, expected11, check_dtype=True) |
| 748 | + |
| 749 | + # Verify the result is actually boolean |
| 750 | + assert result11.dtype == "bool" |
| 751 | + assert expected11.dtype == "bool" |
| 752 | + |
| 753 | + |
| 754 | +def test_full_like_errors(): |
| 755 | + """Test full_like function errors.""" |
| 756 | + x = xtensor("x", dims=("a", "b"), shape=(2, 3), dtype="float64") |
| 757 | + x_test = xr_arange_like(x) |
| 758 | + |
| 759 | + with pytest.raises(ValueError, match="fill_value must be a scalar"): |
| 760 | + full_like(x, x_test) |
| 761 | + |
| 762 | + |
| 763 | +def test_ones_like(): |
| 764 | + """Test ones_like function, comparing with xarray's ones_like.""" |
| 765 | + x = xtensor("x", dims=("a", "b"), shape=(2, 3), dtype="float64") |
| 766 | + x_test = xr_arange_like(x) |
| 767 | + |
| 768 | + y1 = ones_like(x) |
| 769 | + fn1 = xr_function([x], y1) |
| 770 | + result1 = fn1(x_test) |
| 771 | + expected1 = xr_ones_like(x_test) |
| 772 | + xr_assert_allclose(result1, expected1) |
| 773 | + assert result1.dtype == expected1.dtype |
| 774 | + |
| 775 | + |
| 776 | +def test_zeros_like(): |
| 777 | + """Test zeros_like function, comparing with xarray's zeros_like.""" |
| 778 | + x = xtensor("x", dims=("a", "b"), shape=(2, 3), dtype="float64") |
| 779 | + x_test = xr_arange_like(x) |
| 780 | + |
| 781 | + y1 = zeros_like(x) |
| 782 | + fn1 = xr_function([x], y1) |
| 783 | + result1 = fn1(x_test) |
| 784 | + expected1 = xr_zeros_like(x_test) |
| 785 | + xr_assert_allclose(result1, expected1) |
| 786 | + assert result1.dtype == expected1.dtype |
0 commit comments