Skip to content

Commit 0c30e30

Browse files
implement changes in Impact IO and ImpactForecast IO
1 parent 9d6fef9 commit 0c30e30

File tree

3 files changed

+107
-4
lines changed

3 files changed

+107
-4
lines changed

climada/engine/impact.py

Lines changed: 26 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1431,6 +1431,8 @@ def write_attribute(group, name, value):
14311431

14321432
def write_dataset(group, name, value):
    """Store ``value`` as a dataset called ``name`` inside ``group``.

    The ``"lead_time"`` dataset is special-cased: its values are cast to
    ``timedelta64[ns]`` and then to ``int64``, so it is written as integer
    nanosecond counts.
    """
    # NOTE(review): presumably the timedelta dtype cannot be written to the
    # file directly, hence the integer representation -- confirm.
    data = (
        value.astype("timedelta64[ns]").astype("int64")
        if name == "lead_time"
        else value
    )
    group.create_dataset(name, data=data, dtype=_str_type_helper(data))
14351437

14361438
def write_dict(group, name, value):
@@ -1618,7 +1620,9 @@ def read_excel(self, *args, **kwargs):
16181620
self.__dict__ = Impact.from_excel(*args, **kwargs).__dict__
16191621

16201622
@classmethod
1621-
def from_hdf5(cls, file_path: Union[str, Path]):
1623+
def from_hdf5(
1624+
cls, file_path: Union[str, Path], *, add_scalar_attrs=None, add_array_attrs=None
1625+
):
16221626
"""Create an impact object from an H5 file.
16231627
16241628
This assumes a specific layout of the file. If values are not found in the
@@ -1663,6 +1667,10 @@ def from_hdf5(cls, file_path: Union[str, Path]):
16631667
----------
16641668
file_path : str or Path
16651669
The file path of the file to read.
1670+
add_scalar_attrs : Iterable of str, optional
1671+
Scalar attributes to read from file. Defaults to None.
1672+
add_array_attrs : Iterable of str, optional
1673+
Array attributes to read from file. Defaults to None.
16661674
16671675
Returns
16681676
-------
@@ -1691,17 +1699,31 @@ def from_hdf5(cls, file_path: Union[str, Path]):
16911699
# Scalar attributes
16921700
scalar_attrs = set(
16931701
("crs", "tot_value", "unit", "aai_agg", "frequency_unit", "haz_type")
1694-
).intersection(file.attrs.keys())
1702+
)
1703+
if add_scalar_attrs is not None:
1704+
scalar_attrs = scalar_attrs.union(add_scalar_attrs)
1705+
scalar_attrs = scalar_attrs.intersection(file.attrs.keys())
16951706
kwargs.update({attr: file.attrs[attr] for attr in scalar_attrs})
16961707

16971708
# Array attributes
16981709
# NOTE: Need [:] to copy array data. Otherwise, it would be a view that is
16991710
# invalidated once we close the file.
17001711
array_attrs = set(
17011712
("event_id", "date", "coord_exp", "eai_exp", "at_event", "frequency")
1702-
).intersection(file.keys())
1713+
)
1714+
if add_array_attrs is not None:
1715+
array_attrs = array_attrs.union(add_array_attrs)
1716+
array_attrs = array_attrs.intersection(file.keys())
17031717
kwargs.update({attr: file[attr][:] for attr in array_attrs})
1704-
1718+
# Convert the lead_time attribute back to timedelta
1719+
if "lead_time" in array_attrs:
1720+
kwargs.update(
1721+
{
1722+
"lead_time": np.array(file["lead_time"][:]).astype(
1723+
"timedelta64[ns]"
1724+
)
1725+
}
1726+
)
17051727
# Special handling for 'event_name' because it should be a list of strings
17061728
if "event_name" in file:
17071729
# pylint: disable=no-member

climada/engine/impact_forecast.py

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,8 @@
2020
"""
2121

2222
import logging
23+
from pathlib import Path
24+
from typing import Union
2325

2426
import numpy as np
2527

@@ -172,6 +174,62 @@ def calc_freq_curve(self, return_per=None):
172174
LOGGER.error("calc_freq_curve is not defined for ImpactForecast")
173175
raise NotImplementedError("calc_freq_curve is not defined for ImpactForecast")
174176

177+
@classmethod
def from_hdf5(cls, file_path: Union[str, Path]):
    """Read an ImpactForecast from an H5 file.

    Reading is delegated to the parent class reader, which is additionally
    asked to load the ``member`` and ``lead_time`` datasets as array
    attributes.

    The file is expected to follow the layout below (H5 groups end with ``/``,
    attributes are listed under ``.attrs/``)::

        file.h5
        ├─ at_event
        ├─ coord_exp
        ├─ eai_exp
        ├─ event_id
        ├─ event_name
        ├─ frequency
        ├─ imp_mat
        ├─ lead_time
        ├─ member
        ├─ .attrs/
        │   ├─ aai_agg
        │   ├─ crs
        │   ├─ frequency_unit
        │   ├─ haz_type
        │   ├─ tot_value
        │   ├─ unit

    Every entry is optional, as in
    :py:func:`climada.engine.impact.Impact.__init__`. An entry that is missing
    from the file falls back to the default value used when constructing the
    object.

    ``imp_mat`` may be stored in one of two ways. As an H5 dataset it is
    interpreted as the dense representation of the matrix. As an H5 group it
    must provide the data needed to instantiate a
    `scipy.sparse.csr_matrix <https://docs.scipy.org/doc/scipy/reference/generated/scipy.sparse.csr_matrix.html>`_::

        imp_mat/
        ├─ data
        ├─ indices
        ├─ indptr
        ├─ .attrs/
        │   ├─ shape

    Parameters
    ----------
    file_path : str or Path
        Path of the H5 file to read.

    Returns
    -------
    imp : ImpactForecast
        ImpactForecast holding the data from the given file.
    """
    forecast_arrays = {"member", "lead_time"}
    return super().from_hdf5(file_path, add_array_attrs=forecast_arrays)
232+
175233
def _check_sizes(self):
176234
"""Check sizes of forecast data vs. impact data.
177235

climada/engine/test/test_impact_forecast.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -165,3 +165,26 @@ def test_impact_forecast_blocked_methods(impact_forecast):
165165

166166
with pytest.raises(NotImplementedError):
167167
impact_forecast.calc_freq_curve(np.array([10, 50, 100]))
168+
169+
170+
def test_write_read_impact_forecast(impact_forecast, tmp_path):
    """Round-trip an ImpactForecast through HDF5 and compare every attribute.

    The forecast is written with ``write_hdf5`` and read back with
    ``ImpactForecast.from_hdf5``. ``lead_time`` must come back with a
    timedelta dtype, and each entry of ``__dict__`` must match the original.
    """
    file_name = tmp_path / "test_hazard_forecast.h5"
    # Convert the fixture's event names to strings before writing
    impact_forecast.event_name = [str(name) for name in impact_forecast.event_name]

    impact_forecast.write_hdf5(file_name)
    impact_forecast_read = ImpactForecast.from_hdf5(file_name)

    assert impact_forecast_read.lead_time.dtype.kind == np.dtype("timedelta64").kind

    for key, expected in impact_forecast.__dict__.items():
        actual = impact_forecast_read.__dict__[key]
        if key == "imp_mat":
            # Element-wise inequality yields a sparse matrix; no stored
            # entries means the two matrices are equal. This comparison was
            # previously a bare expression and therefore never checked.
            assert (expected != actual).nnz == 0
        else:
            # npt.assert_array_equal also works for comparing int, float or list
            npt.assert_array_equal(expected, actual)

0 commit comments

Comments
 (0)