@@ -41,13 +41,13 @@ def hazard(haz_kwargs):
4141
4242
4343@pytest .fixture
44- def lead_time ():
45- return pd .timedelta_range ("1h" , periods = 6 ).to_numpy ()
44+ def lead_time (haz_kwargs ):
45+ return pd .timedelta_range ("1h" , periods = len ( haz_kwargs [ "event_id" ]) ).to_numpy ()
4646
4747
4848@pytest .fixture
49- def member ():
50- return np .arange (6 )
49+ def member (haz_kwargs ):
50+ return np .arange (len ( haz_kwargs [ "event_id" ]) )
5151
5252
5353@pytest .fixture
@@ -78,6 +78,13 @@ def test_init_hazard_forecast(haz_fc, member, lead_time, haz_kwargs):
7878 assert_hazard_kwargs (haz_fc , ** haz_kwargs )
7979
8080
81+ def test_init_hazard_forecast_error (hazard , member , lead_time , haz_kwargs ):
82+ with pytest .raises (ValueError , match = "Forecast.lead_time" ):
83+ HazardForecast (lead_time = lead_time [:- 2 ], member = member , ** haz_kwargs )
84+ with pytest .raises (ValueError , match = "Forecast.member" ):
85+ HazardForecast .from_hazard (hazard , lead_time = lead_time , member = member [1 :])
86+
87+
8188def test_from_hazard (lead_time , member , hazard , haz_kwargs ):
8289 haz_fc_from_haz = HazardForecast .from_hazard (
8390 hazard , lead_time = lead_time , member = member
@@ -100,10 +107,60 @@ def test_hazard_forecast_concat(haz_fc, lead_time, member):
100107 npt .assert_array_equal (haz_fc_concat .member , np .concatenate ([member , member ]))
101108
102109
103- def test_hazard_forecast_select (haz_fc , lead_time , member ):
110+ @pytest .mark .parametrize (
111+ "var, var_select" ,
112+ [("event_id" , "event_id" ), ("event_name" , "event_names" ), ("date" , "date" )],
113+ )
114+ def test_hazard_forecast_select (haz_fc , lead_time , member , haz_kwargs , var , var_select ):
104115 """Check if Hazard.select works on the derived class"""
105- haz_fc_select = haz_fc .select (event_id = [4 , 1 ])
106- # NOTE: Events keep their original order
107- npt .assert_array_equal (haz_fc_select .event_id , haz_fc .event_id [np .array ([3 , 0 ])])
108- npt .assert_array_equal (haz_fc_select .member , member [np .array ([3 , 0 ])])
109- npt .assert_array_equal (haz_fc_select .lead_time , lead_time [np .array ([3 , 0 ])])
116+
117+ select_mask = np .array ([3 , 2 ])
118+ ordered_select_mask = np .array ([3 , 2 ])
119+ if var == "date" :
120+ # Date needs to be a valid delta
121+ select_mask = np .array ([2 , 3 ])
122+ ordered_select_mask = np .array ([2 , 3 ])
123+
124+ var_value = np .array (haz_kwargs [var ])[select_mask ]
125+ # event_name is a list, convert to numpy array for indexing
126+ haz_fc_sel = haz_fc .select (** {var_select : var_value })
127+ # Note: order is preserved
128+ npt .assert_array_equal (
129+ haz_fc_sel .event_id ,
130+ haz_fc .event_id [ordered_select_mask ],
131+ )
132+ npt .assert_array_equal (
133+ haz_fc_sel .event_name ,
134+ np .array (haz_fc .event_name )[ordered_select_mask ],
135+ )
136+ npt .assert_array_equal (haz_fc_sel .date , haz_fc .date [ordered_select_mask ])
137+ npt .assert_array_equal (haz_fc_sel .frequency , haz_fc .frequency [ordered_select_mask ])
138+ npt .assert_array_equal (haz_fc_sel .member , member [ordered_select_mask ])
139+ npt .assert_array_equal (haz_fc_sel .lead_time , lead_time [ordered_select_mask ])
140+ npt .assert_array_equal (
141+ haz_fc_sel .intensity .todense (),
142+ haz_fc .intensity .todense ()[ordered_select_mask ],
143+ )
144+ npt .assert_array_equal (
145+ haz_fc_sel .fraction .todense (),
146+ haz_fc .fraction .todense ()[ordered_select_mask ],
147+ )
148+
149+ assert haz_fc_sel .centroids == haz_fc .centroids
150+
151+
152+ def test_write_read_hazard_forecast (haz_fc , tmp_path ):
153+
154+ file_name = tmp_path / "test_hazard_forecast.h5"
155+
156+ haz_fc .write_hdf5 (file_name )
157+ haz_fc_read = HazardForecast .from_hdf5 (file_name )
158+
159+ assert haz_fc_read .lead_time .dtype .kind == np .dtype ("timedelta64" ).kind
160+
161+ for key in haz_fc .__dict__ .keys ():
162+ if key in ["intensity" , "fraction" ]:
163+ (haz_fc .__dict__ [key ] != haz_fc_read .__dict__ [key ]).nnz == 0
164+ else :
165+ # npt.assert_array_equal also works for comparing int, float or list
166+ npt .assert_array_equal (haz_fc .__dict__ [key ], haz_fc_read .__dict__ [key ])
0 commit comments