Skip to content

Commit 34612c5

Browse files
adapt from_hdf5 and write_hdf5
1 parent a386d22 commit 34612c5

File tree

2 files changed

+25
-1
lines changed

2 files changed

+25
-1
lines changed

climada/hazard/io.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -917,6 +917,10 @@ def write_hdf5(self, file_name, todense=False):
917917
# Centroids have their own write_hdf5 method,
918918
# which is invoked at the end of this method (s.b.)
919919
continue
920+
elif var_name == "lead_time":
921+
hf_data.create_dataset(
922+
var_name, data=var_val.astype("timedelta64[ns]").astype("int64")
923+
)
920924
elif isinstance(var_val, sparse.csr_matrix):
921925
if todense:
922926
hf_data.create_dataset(var_name, data=var_val.toarray())
@@ -987,7 +991,11 @@ def from_hdf5(cls, file_name):
987991
continue
988992
if var_name == "centroids":
989993
continue
990-
if isinstance(var_val, np.ndarray) and var_val.ndim == 1:
994+
if var_name == "lead_time":
995+
hazard_kwargs[var_name] = np.array(hf_data.get(var_name)).astype(
996+
"timedelta64[ns]"
997+
)
998+
elif isinstance(var_val, np.ndarray) and var_val.ndim == 1:
991999
hazard_kwargs[var_name] = np.array(hf_data.get(var_name))
9921000
elif isinstance(var_val, sparse.csr_matrix):
9931001
hf_csr = hf_data.get(var_name)

climada/hazard/test/test_forecast.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -107,3 +107,19 @@ def test_hazard_forecast_select(haz_fc, lead_time, member):
107107
npt.assert_array_equal(haz_fc_select.event_id, haz_fc.event_id[np.array([3, 0])])
108108
npt.assert_array_equal(haz_fc_select.member, member[np.array([3, 0])])
109109
npt.assert_array_equal(haz_fc_select.lead_time, lead_time[np.array([3, 0])])
110+
111+
112+
def test_write_read_hazard_forecast(haz_fc, tmp_path):
113+
114+
file_name = tmp_path / "test_hazard_forecast.h5"
115+
116+
haz_fc.write_hdf5(file_name)
117+
haz_fc_read = HazardForecast.from_hdf5(file_name)
118+
119+
assert haz_fc_read.lead_time.dtype == np.dtype("timedelta64[ns]")
120+
assert haz_fc_read.member.dtype == int
121+
for key in haz_fc.__dict__.keys():
122+
if key not in ["intensity", "fraction"]:
123+
npt.assert_array_equal(haz_fc.__dict__[key], haz_fc_read.__dict__[key])
124+
else:
125+
(haz_fc.__dict__[key] != haz_fc_read.__dict__[key]).nnz == 0

0 commit comments

Comments
 (0)