Skip to content

Commit 6e9eb50

Browse files
authored
Merge pull request #166 from emileten/correct-DTR-ceiling-boolean-expression
corrected DTR ceiling boolean expression and tests accordingly
2 parents 8df5544 + 257d817 commit 6e9eb50

File tree

2 files changed

+38
-21
lines changed

2 files changed

+38
-21
lines changed

dodola/core.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -590,7 +590,10 @@ def non_polar_dtr_ceiling(ds, ceiling):
590590
"""
591591

592592
ds_corrected = ds.where(
593-
(ds <= ceiling) | (ds["lat"] <= -60 or ds["lat"] >= 60), ceiling
593+
xr.ufuncs.logical_or(
594+
ds <= ceiling, xr.ufuncs.logical_or(ds["lat"] <= -60, ds["lat"] >= 60)
595+
),
596+
ceiling,
594597
)
595598

596599
return ds_corrected

dodola/tests/test_services.py

Lines changed: 34 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from xesmf.data import wave_smooth
88
from xesmf.util import grid_global
99
from xclim.sdba.adjustment import QuantileDeltaMapping
10+
1011
from dodola.services import (
1112
prime_qplad_output_zarrstore,
1213
prime_qdm_output_zarrstore,
@@ -796,38 +797,51 @@ def test_apply_non_polar_dtr_ceiling():
796797

797798
# case 1 : non polar regions, should be applied
798799
# Make some fake dtr data
799-
n = 10
800-
ceiling = 5.0
801-
ts = np.linspace(0.0, 10, num=n)
802-
ds_dtr = _datafactory(ts, start_time="1950-01-01")
800+
ceiling = 0.5
801+
x = np.random.rand(2, 361)
802+
lon = np.arange(-0.5, 0.5, 0.5)
803+
lat = np.arange(-90, 90.5, 0.5)
804+
ds_dtr = xr.Dataset(
805+
{"fakevariable": xr.DataArray(x, {"lon": lon, "lat": lat}, ["lon", "lat"])}
806+
)
807+
803808
in_url = "memory://test_correct_small_dtr/an/input/path.zarr"
804809
out_url = "memory://test_correct_small_dtr/an/output/path.zarr"
805810
repository.write(in_url, ds_dtr)
806811

807812
apply_non_polar_dtr_ceiling(in_url, out=out_url, ceiling=ceiling)
808813
ds_dtr_corrected = repository.read(out_url)
809814

810-
# check values that should be capped
811-
assert all(
812-
x == ceiling
813-
for x in ds_dtr_corrected["fakevariable"].where(
814-
ds_dtr["fakevariable"] > ceiling, drop=True
815-
)
815+
are_not_capped = xr.ufuncs.logical_or(
816+
ds_dtr <= ceiling,
817+
xr.ufuncs.logical_or(ds_dtr["lat"] <= -60, ds_dtr["lat"] >= 60),
816818
)
817819

820+
are_capped = xr.ufuncs.logical_not(are_not_capped)
821+
822+
# check values that should be capped
823+
assert (
824+
ds_dtr_corrected["fakevariable"].values[are_capped["fakevariable"].values]
825+
== ceiling
826+
).all()
827+
818828
# check values that should not be capped
819-
left = ds_dtr_corrected["fakevariable"].where(
820-
ds_dtr["fakevariable"] <= ceiling, drop=True
821-
)
822-
right = ds_dtr["fakevariable"].where(ds_dtr["fakevariable"] <= ceiling, drop=True)
823-
xr.testing.assert_equal(left, right)
829+
corrected = ds_dtr_corrected["fakevariable"].values[
830+
are_not_capped["fakevariable"].values
831+
]
832+
not_corrected = ds_dtr["fakevariable"].values[are_not_capped["fakevariable"].values]
833+
np.testing.assert_equal(corrected, not_corrected)
824834

825-
# case 2 : polar regions, shouldn't be applied
835+
# case 2 : all polar regions, shouldn't be applied
826836
# Make some fake dtr data
827-
n = 10
828-
ceiling = 70.0
829-
ts = np.linspace(65.0, 75.0, num=n)
830-
ds_dtr = _datafactory(ts, start_time="1950-01-01", lat=-61.0)
837+
ceiling = 0.5
838+
x = np.random.rand(2, 58)
839+
lon = np.arange(-0.5, 0.5, 0.5)
840+
lat = np.arange(-90, -61, 0.5)
841+
ds_dtr = xr.Dataset(
842+
{"fakevariable": xr.DataArray(x, {"lon": lon, "lat": lat}, ["lon", "lat"])}
843+
)
844+
831845
in_url = "memory://test_correct_small_dtr/an/input/path.zarr"
832846
out_url = "memory://test_correct_small_dtr/an/output/path.zarr"
833847
repository.write(in_url, ds_dtr)

0 commit comments

Comments
 (0)