1919Tests for Hazard Forecast.
2020"""
2121
22+ import datetime as dt
23+ from pathlib import Path
24+
2225import numpy as np
2326import numpy .testing as npt
2427import pandas as pd
2528import pytest
29+ import xarray as xr
2630from scipy .sparse import csr_matrix
2731
2832from climada .hazard .base import Hazard
@@ -85,7 +89,20 @@ def test_from_hazard(lead_time, member, hazard, haz_kwargs):
8589 assert isinstance (haz_fc_from_haz , HazardForecast )
8690 npt .assert_array_equal (haz_fc_from_haz .lead_time , lead_time )
8791 npt .assert_array_equal (haz_fc_from_haz .member , member )
88- assert_hazard_kwargs (haz_fc_from_haz , ** haz_kwargs )
92+
93+ # Check most hazard kwargs (excluding event_name and date which are auto-generated)
94+ check_kwargs = {
95+ k : v for k , v in haz_kwargs .items () if k not in ["event_name" , "date" ]
96+ }
97+ assert_hazard_kwargs (haz_fc_from_haz , ** check_kwargs )
98+
99+ # Check that event_name and date are auto-generated from lead_time and member
100+ assert len (haz_fc_from_haz .event_name ) == len (lead_time )
101+ assert len (haz_fc_from_haz .date ) == len (lead_time )
102+ # Date should be all zeros for forecast
103+ npt .assert_array_equal (haz_fc_from_haz .date , np .zeros (len (lead_time ), dtype = int ))
104+ # Event names should be formatted with lead_time and member
105+ assert haz_fc_from_haz .event_name [0 ] == f"lt_1h_m_{ member [0 ]} "
89106
90107
91108def test_hazard_forecast_select (haz_fc , lead_time , member ):
@@ -95,3 +112,144 @@ def test_hazard_forecast_select(haz_fc, lead_time, member):
95112 npt .assert_array_equal (haz_fc_select .event_id , haz_fc .event_id [np .array ([3 , 0 ])])
96113 npt .assert_array_equal (haz_fc_select .member , member [np .array ([3 , 0 ])])
97114 npt .assert_array_equal (haz_fc_select .lead_time , lead_time [np .array ([3 , 0 ])])
115+
116+
117+ @pytest .fixture (scope = "module" )
118+ def forecast_netcdf_file (tmp_path_factory ):
119+ """Create a NetCDF file with forecast data structure"""
120+ tmpdir = tmp_path_factory .mktemp ("forecast_data" )
121+ netcdf_path = tmpdir / "forecast_data.nc"
122+
123+ n_eps = 5
124+ n_lead_time = 4
125+ n_lat = 3
126+ n_lon = 4
127+
128+ eps = np .array ([3 , 8 , 13 , 16 , 20 ])
129+ ref_time = np .array ([dt .datetime (2025 , 12 , 8 , 6 , 0 , 0 )], dtype = "datetime64[ns]" )
130+ lead_time_vals = pd .timedelta_range ("3h" , periods = n_lead_time , freq = "2h" ).to_numpy ()
131+ lon = np .array ([10.0 , 10.5 , 11.0 , 11.5 ])
132+ lat = np .array ([45.0 , 45.5 , 46.0 ])
133+
134+ valid_time = ref_time [0 ] + lead_time_vals
135+
136+ np .random .seed (42 )
137+ intensity = np .random .rand (n_eps , 1 , n_lead_time , n_lat , n_lon ) * 10
138+
139+ # Create xarray Dataset
140+ dset = xr .Dataset (
141+ {
142+ "__xarray_dataarray_variable__" : (
143+ ["eps" , "ref_time" , "lead_time" , "lat" , "lon" ],
144+ intensity ,
145+ ),
146+ },
147+ coords = {
148+ "eps" : eps ,
149+ "ref_time" : ref_time ,
150+ "lead_time" : lead_time_vals ,
151+ "lon" : lon ,
152+ "lat" : lat ,
153+ "valid_time" : (["lead_time" ], valid_time ),
154+ },
155+ )
156+ dset .to_netcdf (netcdf_path )
157+
158+ return {
159+ "path" : netcdf_path ,
160+ "n_eps" : n_eps ,
161+ "n_lead_time" : n_lead_time ,
162+ "n_lat" : n_lat ,
163+ "n_lon" : n_lon ,
164+ "eps" : eps ,
165+ "lead_time" : lead_time_vals ,
166+ "lon" : lon ,
167+ "lat" : lat ,
168+ }
169+
170+
171+ def test_from_xarray_raster_basic (forecast_netcdf_file ):
172+ """Test basic loading of forecast hazard from xarray"""
173+ haz_fc = HazardForecast .from_xarray_raster (
174+ forecast_netcdf_file ["path" ],
175+ hazard_type = "PR" ,
176+ intensity_unit = "mm/h" ,
177+ coordinate_vars = {
178+ "longitude" : "lon" ,
179+ "latitude" : "lat" ,
180+ "lead_time" : "lead_time" ,
181+ "member" : "eps" ,
182+ },
183+ )
184+
185+ # Check that it's a HazardForecast instance
186+ assert isinstance (haz_fc , HazardForecast )
187+
188+ # Check dimensions - after stacking, we should have n_eps * n_lead_time events
189+ expected_n_events = (
190+ forecast_netcdf_file ["n_eps" ] * forecast_netcdf_file ["n_lead_time" ]
191+ )
192+ assert len (haz_fc .event_id ) == expected_n_events
193+ assert len (haz_fc .lead_time ) == expected_n_events
194+ assert len (haz_fc .member ) == expected_n_events
195+
196+ # Check that lead_time and member are correctly extracted
197+ npt .assert_array_equal (np .unique (haz_fc .member ), forecast_netcdf_file ["eps" ])
198+
199+ # Check intensity shape (events x centroids)
200+ expected_n_centroids = forecast_netcdf_file ["n_lat" ] * forecast_netcdf_file ["n_lon" ]
201+ assert haz_fc .intensity .shape == (expected_n_events , expected_n_centroids )
202+
203+ # Check centroids
204+ assert len (haz_fc .centroids .lat ) == expected_n_centroids
205+ assert len (haz_fc .centroids .lon ) == expected_n_centroids
206+
207+
208+ def test_from_xarray_raster_event_names (forecast_netcdf_file ):
209+ """Test that event names are auto-generated from lead_time and member"""
210+ haz_fc = HazardForecast .from_xarray_raster (
211+ forecast_netcdf_file ["path" ],
212+ hazard_type = "PR" ,
213+ intensity_unit = "mm/h" ,
214+ coordinate_vars = {
215+ "longitude" : "lon" ,
216+ "latitude" : "lat" ,
217+ "lead_time" : "lead_time" ,
218+ "member" : "eps" ,
219+ },
220+ )
221+
222+ # Check that event names are generated with lead_time in hours
223+ expected_n_events = (
224+ forecast_netcdf_file ["n_eps" ] * forecast_netcdf_file ["n_lead_time" ]
225+ )
226+ assert len (haz_fc .event_name ) == expected_n_events
227+
228+ # First event should be for first lead_time and first member
229+ # Lead time should be in hours (e.g., "lt_3h_m_3")
230+ first_lead_hours = forecast_netcdf_file ["lead_time" ][0 ] / np .timedelta64 (1 , "h" )
231+ expected_first_name = (
232+ f"lt_{ first_lead_hours :.0f} h_m_{ forecast_netcdf_file ['eps' ][0 ]} "
233+ )
234+ assert haz_fc .event_name [0 ] == expected_first_name
235+
236+
237+ def test_from_xarray_raster_dates (forecast_netcdf_file ):
238+ """Test that dates are set to 0 for forecast events"""
239+ haz_fc = HazardForecast .from_xarray_raster (
240+ forecast_netcdf_file ["path" ],
241+ hazard_type = "PR" ,
242+ intensity_unit = "mm/h" ,
243+ coordinate_vars = {
244+ "longitude" : "lon" ,
245+ "latitude" : "lat" ,
246+ "lead_time" : "lead_time" ,
247+ "member" : "eps" ,
248+ },
249+ )
250+
251+ # Check that all dates are 0 (undefined for forecast)
252+ expected_n_events = (
253+ forecast_netcdf_file ["n_eps" ] * forecast_netcdf_file ["n_lead_time" ]
254+ )
255+ npt .assert_array_equal (haz_fc .date , np .zeros (expected_n_events , dtype = int ))
0 commit comments