Skip to content

Commit cc64352

Browse files
add simple first test for hydrostats component (#11)
* add basic unit tests for hydrostats * fix: missing f-strings
1 parent a6283a2 commit cc64352

File tree

3 files changed

+195
-2
lines changed

3 files changed

+195
-2
lines changed

pyproject.toml

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,13 @@ gribjump = [
3636
hyve-extract-timeseries = "hyve.cli:extractor_cli"
3737
hyve-hydrostats = "hyve.cli:stat_calc_cli"
3838

39+
# NOTE(review): the native [tool.pytest] table is only read by pytest >= 9.0; older
# releases only read [tool.pytest.ini_options] -- confirm the minimum pytest version.
[tool.pytest]
40+
filterwarnings = [
41+
# This warning is raised due to netcdf4 but likely is harmless. See
42+
# https://github.com/Unidata/netcdf4-python/issues/1354 for more context.
43+
"ignore:numpy.ndarray size changed"
44+
]
45+
3946
[tool.black]
4047
line-length = 88
4148
skip-string-normalization = false

src/hyve/core.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,9 +24,11 @@ def find_main_var(ds, min_dim=2):
2424
"""
2525
variable_names = [k for k in ds.variables if len(ds.variables[k].dims) >= min_dim]
2626
if len(variable_names) > 1:
27-
raise ValueError("More than one variable of dimension >= {min_dim} in dataset.")
27+
raise ValueError(
28+
f"More than one variable of dimension >= {min_dim} in dataset."
29+
)
2830
elif len(variable_names) == 0:
29-
raise ValueError("No variable of dimension >= {min_dim} in dataset.")
31+
raise ValueError(f"No variable of dimension >= {min_dim} in dataset.")
3032
else:
3133
return variable_names[0]
3234

tests/test_hydrostats.py

Lines changed: 184 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,184 @@
1+
import numpy as np
2+
import pytest
3+
4+
from hyve.hydrostats.stat_calc import stat_calc
5+
6+
7+
def _wrap_lod(dicts: list[dict]) -> dict:
8+
"""Simple wrapper to turn a list of dicts into an earthkit-data source config."""
9+
return {"list-of-dicts": {"list_of_dicts": dicts}}
10+
11+
12+
@pytest.fixture
def simulation_source_config():
    """Provide an earthkit-data list-of-dicts source config with simulated data."""
    # One tuple per record: (date, time, number, value).
    rows = [
        (20240101, 0, 1, 1.0),
        (20240102, 0, 1, 2.0),
        (20240101, 0, 2, 1.0),
        (20240102, 0, 2, 2.0),
        # Stations that only exist in sim should be ignored
        (20240101, 0, 3, 8.0),
        (20240102, 0, 3, 9.0),
        # Times that only exist in sim should be ignored
        (20240101, 6, 1, 7.0),
        (20240101, 6, 2, 4.0),
        (20240101, 6, 3, 5.0),
    ]
    records = [
        {
            "date": date_,
            "time": time_,
            "number": member,
            "param": "dis",
            "values": [value],
        }
        for date_, time_, member, value in rows
    ]
    return _wrap_lod(records)
30+
31+
32+
@pytest.fixture
def observation_source_config():
    """Provide an earthkit-data list-of-dicts source config with observed data."""
    # One tuple per record: (date, time, number, value).
    rows = [
        (20240101, 0, 1, 1.0),
        (20240102, 0, 1, 2.0),
        (20240101, 0, 2, 2.0),
        (20240102, 0, 2, 1.0),
        # Stations that only exist in obs should be ignored
        (20240101, 0, 4, 9.0),
        (20240102, 0, 4, 9.0),
        # Times that only exist in obs should be ignored
        (20240102, 6, 1, 1.0),
        (20240102, 6, 2, 5.0),
        (20240102, 6, 4, 9.0),
    ]
    records = [
        {
            "date": date_,
            "time": time_,
            "number": member,
            "param": "d",
            "values": [value],
        }
        for date_, time_, member, value in rows
    ]
    return _wrap_lod(records)
50+
51+
52+
def test_stat_calc(simulation_source_config, observation_source_config):
    """Test stat_calc with synthetic data.

    Simulated and observed data each have two stations and two times in common.
    Station 1 has identical values in sim and obs, leading to MAE=0 and correlation=1.
    Station 2 has inverted values in sim and obs, leading to MAE=1 and correlation=-1.

    The test checks that only the common stations and times are considered, and that
    the calculated statistics match the expected results. Statistics are selected by
    station label, so the checks do not depend on the order of stations in the output.

    Albeit a bit hacky, this test uses the list-of-dicts data source for
    simplicity and uses the ensemble member dimension to represent stations.
    """
    config = {
        "sim": {
            "source": simulation_source_config,
            "coords": {"s": "number", "t": "forecast_reference_time"},
        },
        "obs": {
            "source": observation_source_config,
            "to_xarray_options": {"time_dim_mode": "valid_time"},
            "coords": {"s": "number", "t": "valid_time"},
        },
        "stats": ["mae", "correlation"],
        "output": {"coords": {"s": "station", "t": "time"}},
    }

    result = stat_calc(config)

    # Only the stations present in both sim and obs may survive.
    assert set(result["station"].to_numpy()) == {1, 2}

    # Pin the station order explicitly instead of relying on the output ordering.
    stations = [1, 2]
    mae = result["mae"].sel(station=stations).to_numpy().ravel()
    correlation = result["correlation"].sel(station=stations).to_numpy().ravel()
    np.testing.assert_allclose(mae, [0.0, 1.0])
    np.testing.assert_allclose(correlation, [1.0, -1.0])
84+
85+
86+
def test_stat_calc_with_nans():
    """Exercise stat_calc on inputs containing NaN values.

    NOTE: This test documents current behavior, not necessarily desired behavior.

    Simulated and observed data cover three stations over two time points, with
    NaN values placed in different configurations:
    - Station 1: overlapping valid data points, only one data point is NaN in obs
    - Station 2: no overlapping valid data points (one NaN in sim, one NaN in obs)
    - Station 3: only NaN values in sim and one NaN in obs
    """
    nan = np.nan
    dates = (20240101, 20240102)

    def _as_source(series_by_station):
        # Expand {station: (value per date)} into one record per station and date,
        # station by station, with dates in ascending order.
        return _wrap_lod(
            [
                {
                    "date": date_,
                    "time": 0,
                    "number": station,
                    "param": "q",
                    "values": [value],
                }
                for station, series in series_by_station.items()
                for date_, value in zip(dates, series)
            ]
        )

    sim_source = _as_source({1: (1.0, -1.0), 2: (1.0, nan), 3: (nan, nan)})
    obs_source = _as_source({1: (1.0, nan), 2: (nan, 2.0), 3: (nan, 1.0)})

    shared_coords = {"s": "number", "t": "forecast_reference_time"}
    config = {
        "sim": {"source": sim_source, "coords": dict(shared_coords)},
        "obs": {"source": obs_source, "coords": dict(shared_coords)},
        "stats": ["mae", "br"],
        "output": {"coords": {"s": "station", "t": "time"}},
    }

    result = stat_calc(config)

    def _stat(name, station):
        return result[name].sel(station=station).to_numpy()

    # Station 1: sim=[1.0, -1.0], obs=[1.0, nan] -> mae=0.0, br=0.0
    np.testing.assert_allclose(_stat("mae", 1), 0.0)
    np.testing.assert_allclose(_stat("br", 1), 0.0)

    # Station 2: sim=[1.0, nan], obs=[nan, 2.0] -> mae=nan, br=0.5
    assert np.isnan(_stat("mae", 2))
    np.testing.assert_allclose(_stat("br", 2), 0.5)

    # Station 3: sim=[nan, nan], obs=[nan, 1.0] -> mae=nan, br=nan
    assert np.isnan(_stat("mae", 3))
    assert np.isnan(_stat("br", 3))

0 commit comments

Comments
 (0)