|
7 | 7 | from xesmf.data import wave_smooth |
8 | 8 | from xesmf.util import grid_global |
9 | 9 | from xclim.sdba.adjustment import QuantileDeltaMapping |
| 10 | + |
10 | 11 | from dodola.services import ( |
11 | 12 | prime_qplad_output_zarrstore, |
12 | 13 | prime_qdm_output_zarrstore, |
@@ -796,38 +797,51 @@ def test_apply_non_polar_dtr_ceiling(): |
796 | 797 |
|
797 | 798 | # case 1 : non polar regions, should be applied |
798 | 799 | # Make some fake dtr data |
799 | | - n = 10 |
800 | | - ceiling = 5.0 |
801 | | - ts = np.linspace(0.0, 10, num=n) |
802 | | - ds_dtr = _datafactory(ts, start_time="1950-01-01") |
| 800 | + ceiling = 0.5 |
| 801 | + x = np.random.rand(2, 361) |
| 802 | + lon = np.arange(-0.5, 0.5, 0.5) |
| 803 | + lat = np.arange(-90, 90.5, 0.5) |
| 804 | + ds_dtr = xr.Dataset( |
| 805 | + {"fakevariable": xr.DataArray(x, {"lon": lon, "lat": lat}, ["lon", "lat"])} |
| 806 | + ) |
| 807 | + |
803 | 808 | in_url = "memory://test_correct_small_dtr/an/input/path.zarr" |
804 | 809 | out_url = "memory://test_correct_small_dtr/an/output/path.zarr" |
805 | 810 | repository.write(in_url, ds_dtr) |
806 | 811 |
|
807 | 812 | apply_non_polar_dtr_ceiling(in_url, out=out_url, ceiling=ceiling) |
808 | 813 | ds_dtr_corrected = repository.read(out_url) |
809 | 814 |
|
810 | | - # check values that should be capped |
811 | | - assert all( |
812 | | - x == ceiling |
813 | | - for x in ds_dtr_corrected["fakevariable"].where( |
814 | | - ds_dtr["fakevariable"] > ceiling, drop=True |
815 | | - ) |
| 815 | + are_not_capped = xr.ufuncs.logical_or( |
| 816 | + ds_dtr <= ceiling, |
| 817 | + xr.ufuncs.logical_or(ds_dtr["lat"] <= -60, ds_dtr["lat"] >= 60), |
816 | 818 | ) |
817 | 819 |
|
| 820 | + are_capped = xr.ufuncs.logical_not(are_not_capped) |
| 821 | + |
| 822 | + # check values that should be capped |
| 823 | + assert ( |
| 824 | + ds_dtr_corrected["fakevariable"].values[are_capped["fakevariable"].values] |
| 825 | + == ceiling |
| 826 | + ).all() |
| 827 | + |
818 | 828 | # check values that should not be capped |
819 | | - left = ds_dtr_corrected["fakevariable"].where( |
820 | | - ds_dtr["fakevariable"] <= ceiling, drop=True |
821 | | - ) |
822 | | - right = ds_dtr["fakevariable"].where(ds_dtr["fakevariable"] <= ceiling, drop=True) |
823 | | - xr.testing.assert_equal(left, right) |
| 829 | + corrected = ds_dtr_corrected["fakevariable"].values[ |
| 830 | + are_not_capped["fakevariable"].values |
| 831 | + ] |
| 832 | + not_corrected = ds_dtr["fakevariable"].values[are_not_capped["fakevariable"].values] |
| 833 | + np.testing.assert_equal(corrected, not_corrected) |
824 | 834 |
|
825 | | - # case 2 : polar regions, shouldn't be applied |
| 835 | + # case 2 : all polar regions, shouldn't be applied |
826 | 836 | # Make some fake dtr data |
827 | | - n = 10 |
828 | | - ceiling = 70.0 |
829 | | - ts = np.linspace(65.0, 75.0, num=n) |
830 | | - ds_dtr = _datafactory(ts, start_time="1950-01-01", lat=-61.0) |
| 837 | + ceiling = 0.5 |
| 838 | + x = np.random.rand(2, 58) |
| 839 | + lon = np.arange(-0.5, 0.5, 0.5) |
| 840 | + lat = np.arange(-90, -61, 0.5) |
| 841 | + ds_dtr = xr.Dataset( |
| 842 | + {"fakevariable": xr.DataArray(x, {"lon": lon, "lat": lat}, ["lon", "lat"])} |
| 843 | + ) |
| 844 | + |
831 | 845 | in_url = "memory://test_correct_small_dtr/an/input/path.zarr" |
832 | 846 | out_url = "memory://test_correct_small_dtr/an/output/path.zarr" |
833 | 847 | repository.write(in_url, ds_dtr) |
|
0 commit comments