Skip to content

Commit 3d7e37a

Browse files
committed
tests for xarray hazard fc reader
1 parent d01aef4 commit 3d7e37a

File tree

1 file changed

+159
-1
lines changed

1 file changed

+159
-1
lines changed

climada/hazard/test/test_forecast.py

Lines changed: 159 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,10 +19,14 @@
1919
Tests for Hazard Forecast.
2020
"""
2121

22+
import datetime as dt
23+
from pathlib import Path
24+
2225
import numpy as np
2326
import numpy.testing as npt
2427
import pandas as pd
2528
import pytest
29+
import xarray as xr
2630
from scipy.sparse import csr_matrix
2731

2832
from climada.hazard.base import Hazard
@@ -85,7 +89,20 @@ def test_from_hazard(lead_time, member, hazard, haz_kwargs):
8589
assert isinstance(haz_fc_from_haz, HazardForecast)
8690
npt.assert_array_equal(haz_fc_from_haz.lead_time, lead_time)
8791
npt.assert_array_equal(haz_fc_from_haz.member, member)
88-
assert_hazard_kwargs(haz_fc_from_haz, **haz_kwargs)
92+
93+
# Check most hazard kwargs (excluding event_name and date which are auto-generated)
94+
check_kwargs = {
95+
k: v for k, v in haz_kwargs.items() if k not in ["event_name", "date"]
96+
}
97+
assert_hazard_kwargs(haz_fc_from_haz, **check_kwargs)
98+
99+
# Check that event_name and date are auto-generated from lead_time and member
100+
assert len(haz_fc_from_haz.event_name) == len(lead_time)
101+
assert len(haz_fc_from_haz.date) == len(lead_time)
102+
# Date should be all zeros for forecast
103+
npt.assert_array_equal(haz_fc_from_haz.date, np.zeros(len(lead_time), dtype=int))
104+
# Event names should be formatted with lead_time and member
105+
assert haz_fc_from_haz.event_name[0] == f"lt_1h_m_{member[0]}"
89106

90107

91108
def test_hazard_forecast_select(haz_fc, lead_time, member):
@@ -95,3 +112,144 @@ def test_hazard_forecast_select(haz_fc, lead_time, member):
95112
npt.assert_array_equal(haz_fc_select.event_id, haz_fc.event_id[np.array([3, 0])])
96113
npt.assert_array_equal(haz_fc_select.member, member[np.array([3, 0])])
97114
npt.assert_array_equal(haz_fc_select.lead_time, lead_time[np.array([3, 0])])
115+
116+
117+
@pytest.fixture(scope="module")
118+
def forecast_netcdf_file(tmp_path_factory):
119+
"""Create a NetCDF file with forecast data structure"""
120+
tmpdir = tmp_path_factory.mktemp("forecast_data")
121+
netcdf_path = tmpdir / "forecast_data.nc"
122+
123+
n_eps = 5
124+
n_lead_time = 4
125+
n_lat = 3
126+
n_lon = 4
127+
128+
eps = np.array([3, 8, 13, 16, 20])
129+
ref_time = np.array([dt.datetime(2025, 12, 8, 6, 0, 0)], dtype="datetime64[ns]")
130+
lead_time_vals = pd.timedelta_range("3h", periods=n_lead_time, freq="2h").to_numpy()
131+
lon = np.array([10.0, 10.5, 11.0, 11.5])
132+
lat = np.array([45.0, 45.5, 46.0])
133+
134+
valid_time = ref_time[0] + lead_time_vals
135+
136+
np.random.seed(42)
137+
intensity = np.random.rand(n_eps, 1, n_lead_time, n_lat, n_lon) * 10
138+
139+
# Create xarray Dataset
140+
dset = xr.Dataset(
141+
{
142+
"__xarray_dataarray_variable__": (
143+
["eps", "ref_time", "lead_time", "lat", "lon"],
144+
intensity,
145+
),
146+
},
147+
coords={
148+
"eps": eps,
149+
"ref_time": ref_time,
150+
"lead_time": lead_time_vals,
151+
"lon": lon,
152+
"lat": lat,
153+
"valid_time": (["lead_time"], valid_time),
154+
},
155+
)
156+
dset.to_netcdf(netcdf_path)
157+
158+
return {
159+
"path": netcdf_path,
160+
"n_eps": n_eps,
161+
"n_lead_time": n_lead_time,
162+
"n_lat": n_lat,
163+
"n_lon": n_lon,
164+
"eps": eps,
165+
"lead_time": lead_time_vals,
166+
"lon": lon,
167+
"lat": lat,
168+
}
169+
170+
171+
def test_from_xarray_raster_basic(forecast_netcdf_file):
172+
"""Test basic loading of forecast hazard from xarray"""
173+
haz_fc = HazardForecast.from_xarray_raster(
174+
forecast_netcdf_file["path"],
175+
hazard_type="PR",
176+
intensity_unit="mm/h",
177+
coordinate_vars={
178+
"longitude": "lon",
179+
"latitude": "lat",
180+
"lead_time": "lead_time",
181+
"member": "eps",
182+
},
183+
)
184+
185+
# Check that it's a HazardForecast instance
186+
assert isinstance(haz_fc, HazardForecast)
187+
188+
# Check dimensions - after stacking, we should have n_eps * n_lead_time events
189+
expected_n_events = (
190+
forecast_netcdf_file["n_eps"] * forecast_netcdf_file["n_lead_time"]
191+
)
192+
assert len(haz_fc.event_id) == expected_n_events
193+
assert len(haz_fc.lead_time) == expected_n_events
194+
assert len(haz_fc.member) == expected_n_events
195+
196+
# Check that lead_time and member are correctly extracted
197+
npt.assert_array_equal(np.unique(haz_fc.member), forecast_netcdf_file["eps"])
198+
199+
# Check intensity shape (events x centroids)
200+
expected_n_centroids = forecast_netcdf_file["n_lat"] * forecast_netcdf_file["n_lon"]
201+
assert haz_fc.intensity.shape == (expected_n_events, expected_n_centroids)
202+
203+
# Check centroids
204+
assert len(haz_fc.centroids.lat) == expected_n_centroids
205+
assert len(haz_fc.centroids.lon) == expected_n_centroids
206+
207+
208+
def test_from_xarray_raster_event_names(forecast_netcdf_file):
209+
"""Test that event names are auto-generated from lead_time and member"""
210+
haz_fc = HazardForecast.from_xarray_raster(
211+
forecast_netcdf_file["path"],
212+
hazard_type="PR",
213+
intensity_unit="mm/h",
214+
coordinate_vars={
215+
"longitude": "lon",
216+
"latitude": "lat",
217+
"lead_time": "lead_time",
218+
"member": "eps",
219+
},
220+
)
221+
222+
# Check that event names are generated with lead_time in hours
223+
expected_n_events = (
224+
forecast_netcdf_file["n_eps"] * forecast_netcdf_file["n_lead_time"]
225+
)
226+
assert len(haz_fc.event_name) == expected_n_events
227+
228+
# First event should be for first lead_time and first member
229+
# Lead time should be in hours (e.g., "lt_3h_m_3")
230+
first_lead_hours = forecast_netcdf_file["lead_time"][0] / np.timedelta64(1, "h")
231+
expected_first_name = (
232+
f"lt_{first_lead_hours:.0f}h_m_{forecast_netcdf_file['eps'][0]}"
233+
)
234+
assert haz_fc.event_name[0] == expected_first_name
235+
236+
237+
def test_from_xarray_raster_dates(forecast_netcdf_file):
238+
"""Test that dates are set to 0 for forecast events"""
239+
haz_fc = HazardForecast.from_xarray_raster(
240+
forecast_netcdf_file["path"],
241+
hazard_type="PR",
242+
intensity_unit="mm/h",
243+
coordinate_vars={
244+
"longitude": "lon",
245+
"latitude": "lat",
246+
"lead_time": "lead_time",
247+
"member": "eps",
248+
},
249+
)
250+
251+
# Check that all dates are 0 (undefined for forecast)
252+
expected_n_events = (
253+
forecast_netcdf_file["n_eps"] * forecast_netcdf_file["n_lead_time"]
254+
)
255+
npt.assert_array_equal(haz_fc.date, np.zeros(expected_n_events, dtype=int))

0 commit comments

Comments
 (0)