Skip to content

Commit 3e7efbe

Browse files
committed
Update tests
1 parent 0c30e30 commit 3e7efbe

File tree

1 file changed

+21
-14
lines changed

1 file changed

+21
-14
lines changed

climada/engine/test/test_impact_forecast.py

Lines changed: 21 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -167,24 +167,31 @@ def test_impact_forecast_blocked_methods(impact_forecast):
167167
impact_forecast.calc_freq_curve(np.array([10, 50, 100]))
168168

169169

170-
def test_write_read_impact_forecast(impact_forecast, tmp_path):
170+
@pytest.mark.parametrize("dense", [True, False])
171+
def test_write_read_hdf5(impact_forecast, tmp_path, dense):
171172

172173
file_name = tmp_path / "test_hazard_forecast.h5"
173174
# replace dummy_impact event_names with strings
174175
impact_forecast.event_name = [str(name) for name in impact_forecast.event_name]
176+
impact_forecast.write_hdf5(file_name, dense_imp_mat=dense)
175177

176-
impact_forecast.write_hdf5(file_name)
177-
impact_forecast_read = ImpactForecast.from_hdf5(file_name)
178+
def compare_attr(obj, attr):
179+
actual = getattr(obj, attr)
180+
expected = getattr(impact_forecast, attr)
181+
if isinstance(actual, csr_matrix):
182+
npt.assert_array_equal(actual.todense(), expected.todense())
183+
else:
184+
npt.assert_array_equal(actual, expected)
178185

186+
# Read ImpactForecast
187+
impact_forecast_read = ImpactForecast.from_hdf5(file_name)
179188
assert impact_forecast_read.lead_time.dtype.kind == np.dtype("timedelta64").kind
180-
181-
for key in impact_forecast.__dict__.keys():
182-
if key in ["imp_mat"]:
183-
(
184-
impact_forecast.__dict__[key] != impact_forecast_read.__dict__[key]
185-
).nnz == 0
186-
else:
187-
# npt.assert_array_equal also works for comparing int, float or list
188-
npt.assert_array_equal(
189-
impact_forecast.__dict__[key], impact_forecast_read.__dict__[key]
190-
)
189+
for attr in impact_forecast.__dict__.keys():
190+
compare_attr(impact_forecast_read, attr)
191+
192+
# Read Impact
193+
impact_read = Impact.from_hdf5(file_name)
194+
for attr in impact_read.__dict__.keys():
195+
compare_attr(impact_read, attr)
196+
assert "member" not in impact_read.__dict__
197+
assert "lead_time" not in impact_read.__dict__

0 commit comments

Comments
 (0)