@@ -167,24 +167,31 @@ def test_impact_forecast_blocked_methods(impact_forecast):
167167 impact_forecast .calc_freq_curve (np .array ([10 , 50 , 100 ]))
168168
169169
170- def test_write_read_impact_forecast (impact_forecast , tmp_path ):
170+ @pytest .mark .parametrize ("dense" , [True , False ])
171+ def test_write_read_hdf5 (impact_forecast , tmp_path , dense ):
171172
172173 file_name = tmp_path / "test_hazard_forecast.h5"
173174 # replace dummy_impact event_names with strings
174175 impact_forecast .event_name = [str (name ) for name in impact_forecast .event_name ]
176+ impact_forecast .write_hdf5 (file_name , dense_imp_mat = dense )
175177
176- impact_forecast .write_hdf5 (file_name )
177- impact_forecast_read = ImpactForecast .from_hdf5 (file_name )
178+ def compare_attr (obj , attr ):
179+ actual = getattr (obj , attr )
180+ expected = getattr (impact_forecast , attr )
181+ if isinstance (actual , csr_matrix ):
182+ npt .assert_array_equal (actual .todense (), expected .todense ())
183+ else :
184+ npt .assert_array_equal (actual , expected )
178185
186+ # Read ImpactForecast
187+ impact_forecast_read = ImpactForecast .from_hdf5 (file_name )
179188 assert impact_forecast_read .lead_time .dtype .kind == np .dtype ("timedelta64" ).kind
180-
181- for key in impact_forecast .__dict__ .keys ():
182- if key in ["imp_mat" ]:
183- (
184- impact_forecast .__dict__ [key ] != impact_forecast_read .__dict__ [key ]
185- ).nnz == 0
186- else :
187- # npt.assert_array_equal also works for comparing int, float or list
188- npt .assert_array_equal (
189- impact_forecast .__dict__ [key ], impact_forecast_read .__dict__ [key ]
190- )
189+ for attr in impact_forecast .__dict__ .keys ():
190+ compare_attr (impact_forecast_read , attr )
191+
192+ # Read Impact
193+ impact_read = Impact .from_hdf5 (file_name )
194+ for attr in impact_read .__dict__ .keys ():
195+ compare_attr (impact_read , attr )
196+ assert "member" not in impact_read .__dict__
197+ assert "lead_time" not in impact_read .__dict__
0 commit comments