Skip to content

Commit a55ab3b

Browse files
committed
Merge branch 'forecast-class' into implement_mean_min_max
2 parents 0bc5846 + 9d6fef9 commit a55ab3b

File tree

7 files changed

+272
-35
lines changed

7 files changed

+272
-35
lines changed

climada/engine/impact_forecast.py

Lines changed: 44 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
import scipy.sparse as sparse
2626

2727
from ..util import log_level
28+
from ..util.checker import size
2829
from ..util.forecast import Forecast
2930
from .impact import Impact
3031

@@ -52,8 +53,8 @@ def __init__(
5253
impact_kwargs
5354
Keyword-arguments passed to ~:py:class`climada.engine.impact.Impact`.
5455
"""
55-
# TODO: Maybe assert array lengths?
5656
super().__init__(lead_time=lead_time, member=member, **impact_kwargs)
57+
self._check_sizes()
5758

5859
@classmethod
5960
def from_impact(
@@ -92,13 +93,16 @@ def from_impact(
9293

9394
@property
9495
def at_event(self):
95-
LOGGER.warning("at_event for forecasts is not yet implemented.")
96+
"""Get the total impact for each member/lead_time combination."""
97+
LOGGER.warning(
98+
"at_event gives the total impact for one specific combination of member and "
99+
"lead_time."
100+
)
96101
return self._at_event
97102

98103
@at_event.setter
99104
def at_event(self, value):
100-
"""Set the total exposure value close to a hazard"""
101-
LOGGER.warning("at_event for forecasts is not yet implemented.")
105+
"""Set the total impact for each member/lead_time combination."""
102106
self._at_event = value
103107

104108
def local_exceedance_impact(
@@ -111,9 +115,14 @@ def local_exceedance_impact(
111115
bin_decimals=None,
112116
):
113117
"""Compution of local exceedance impact for given return periods is not
114-
implemented for ImpactForecast. See climada.engine.impact.Impact for details.
115-
Returns
116-
-------
118+
implemented for ImpactForecast.
119+
120+
See Also
121+
--------
122+
See :py:meth:`~climada.engine.impact.Impact.local_exceedance_impact`
123+
124+
Raises
125+
------
117126
NotImplementedError
118127
"""
119128

@@ -132,8 +141,13 @@ def local_return_period(
132141
bin_decimals=None,
133142
):
134143
"""Compution of local return period for given impact thresholds is not
135-
implemented for ImpactForecast. See climada.engine.impact.Impact for details.
136-
Returns
144+
implemented for ImpactForecast.
145+
146+
See Also
147+
--------
148+
See :py:meth:`~climada.engine.impact.Impact.local_return_period`
149+
150+
Raises
137151
-------
138152
NotImplementedError
139153
"""
@@ -145,15 +159,33 @@ def local_return_period(
145159

146160
def calc_freq_curve(self, return_per=None):
147161
"""Computation of the impact exceedance frequency curve is not
148-
implemented for ImpactForecast. See climada.engine.impact.Impact for details.
149-
Returns
150-
-------
162+
implemented for ImpactForecast.
163+
164+
See Also
165+
--------
166+
See :py:meth:`~climada.engine.impact.Impact.calc_freq_curve`
167+
168+
Raises
169+
------
151170
NotImplementedError
152171
"""
153172

154173
LOGGER.error("calc_freq_curve is not defined for ImpactForecast")
155174
raise NotImplementedError("calc_freq_curve is not defined for ImpactForecast")
156175

176+
def _check_sizes(self):
177+
"""Check sizes of forecast data vs. impact data.
178+
179+
Raises
180+
------
181+
ValueError
182+
If the sizes of the forecast data do not match the
183+
:py:attr:`~climada.engine.impact.Impact.event_id`
184+
"""
185+
num_entries = len(self.event_id)
186+
size(exp_len=num_entries, var=self.member, var_name="Forecast.member")
187+
size(exp_len=num_entries, var=self.lead_time, var_name="Forecast.lead_time")
188+
157189
def _reduce_attrs(self, reduce_method: str):
158190
"""
159191
Reduce the attributes of an ImpactForecast to a single value.

climada/engine/test/test_impact_forecast.py

Lines changed: 60 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -41,13 +41,15 @@ def impact(impact_kwargs):
4141

4242

4343
@pytest.fixture
44-
def lead_time():
45-
return pd.timedelta_range(start="1 day", periods=6).to_numpy()
44+
def lead_time(impact_kwargs):
45+
return pd.timedelta_range(
46+
start="1 day", periods=len(impact_kwargs["event_id"])
47+
).to_numpy()
4648

4749

4850
@pytest.fixture
49-
def member():
50-
return np.arange(6)
51+
def member(impact_kwargs):
52+
return np.arange(len(impact_kwargs["event_id"]))
5153

5254

5355
@pytest.fixture
@@ -76,6 +78,12 @@ def test_impact_forecast_init(self, impact_kwargs, lead_time, member):
7678
npt.assert_array_equal(forecast1.member, member)
7779
self.assert_impact_kwargs(forecast1, **impact_kwargs)
7880

81+
def test_impact_forecast_init_error(self, impact, impact_kwargs, lead_time, member):
82+
with pytest.raises(ValueError, match="Forecast.lead_time"):
83+
ImpactForecast(lead_time=lead_time[:-2], member=member, **impact_kwargs)
84+
with pytest.raises(ValueError, match="Forecast.member"):
85+
ImpactForecast.from_impact(impact, lead_time=lead_time, member=member[1:])
86+
7987
def test_impact_forecast_from_impact(
8088
self, impact_forecast, impact_kwargs, lead_time, member
8189
):
@@ -84,16 +92,58 @@ def test_impact_forecast_from_impact(
8492
self.assert_impact_kwargs(impact_forecast, **impact_kwargs)
8593

8694

87-
def test_impact_forecast_select(impact_forecast, lead_time, member, impact_kwargs):
95+
@pytest.mark.parametrize(
96+
"var, var_select",
97+
[("event_id", "event_ids"), ("event_name", "event_names"), ("date", "dates")],
98+
)
99+
def test_impact_forecast_select_events(
100+
impact_forecast, lead_time, member, impact_kwargs, var, var_select
101+
):
88102
"""Check if Impact.select works on the derived class"""
89-
event_ids = impact_kwargs["event_id"][np.array([2, 0])]
90-
impact_fc = impact_forecast.select(event_ids=event_ids)
103+
select_mask = np.array([2, 1])
104+
ordered_select_mask = np.array([1, 2])
105+
if var == "date":
106+
# Date needs to be a valid delta
107+
select_mask = np.array([1, 2])
108+
ordered_select_mask = np.array([1, 2])
109+
110+
var_value = np.array(impact_kwargs[var])[select_mask]
111+
# event_name is a list, convert to numpy array for indexing
112+
impact_fc = impact_forecast.select(**{var_select: var_value})
91113
# NOTE: Events keep their original order
92114
npt.assert_array_equal(
93-
impact_fc.event_id, impact_forecast.event_id[np.array([0, 2])]
115+
impact_fc.event_id,
116+
impact_forecast.event_id[ordered_select_mask],
117+
)
118+
npt.assert_array_equal(
119+
impact_fc.event_name,
120+
np.array(impact_forecast.event_name)[ordered_select_mask],
121+
)
122+
npt.assert_array_equal(impact_fc.date, impact_forecast.date[ordered_select_mask])
123+
npt.assert_array_equal(
124+
impact_fc.frequency, impact_forecast.frequency[ordered_select_mask]
125+
)
126+
npt.assert_array_equal(impact_fc.member, member[ordered_select_mask])
127+
npt.assert_array_equal(impact_fc.lead_time, lead_time[ordered_select_mask])
128+
npt.assert_array_equal(
129+
impact_fc.imp_mat.todense(),
130+
impact_forecast.imp_mat.todense()[ordered_select_mask],
131+
)
132+
133+
134+
def test_impact_forecast_select_exposure(
135+
impact_forecast, lead_time, member, impact_kwargs
136+
):
137+
"""Check if Impact.select works on the derived class"""
138+
exp_col = 0
139+
select_mask = np.array([exp_col])
140+
coord_exp = impact_kwargs["coord_exp"][select_mask]
141+
impact_fc = impact_forecast.select(coord_exp=coord_exp)
142+
npt.assert_array_equal(impact_fc.member, member)
143+
npt.assert_array_equal(impact_fc.lead_time, lead_time)
144+
npt.assert_array_equal(
145+
impact_fc.imp_mat.todense(), impact_forecast.imp_mat.todense()[:, exp_col]
94146
)
95-
npt.assert_array_equal(impact_fc.member, member[np.array([0, 2])])
96-
npt.assert_array_equal(impact_fc.lead_time, lead_time[np.array([0, 2])])
97147

98148

99149
@pytest.mark.skip("Concat from base class does not work")

climada/hazard/forecast.py

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,9 @@
2323

2424
import numpy as np
2525

26-
from climada.hazard.base import Hazard
27-
from climada.util.forecast import Forecast
26+
from ..util.checker import size
27+
from ..util.forecast import Forecast
28+
from .base import Hazard
2829

2930
LOGGER = logging.getLogger(__name__)
3031

@@ -52,6 +53,7 @@ def __init__(
5253
py:meth`~climada.hazard.base.Hazard.__init__` for details.
5354
"""
5455
super().__init__(lead_time=lead_time, member=member, **hazard_kwargs)
56+
self._check_sizes()
5557

5658
@classmethod
5759
def from_hazard(cls, hazard: Hazard, lead_time: np.ndarray, member: np.ndarray):
@@ -89,3 +91,16 @@ def from_hazard(cls, hazard: Hazard, lead_time: np.ndarray, member: np.ndarray):
8991
intensity=hazard.intensity,
9092
fraction=hazard.fraction,
9193
)
94+
95+
def _check_sizes(self):
96+
"""Check sizes of forecast data vs. hazard data.
97+
98+
Raises
99+
------
100+
ValueError
101+
If the sizes of the forecast data do not match the
102+
:py:attr:`~climada.hazard.base.Hazard.event_id`
103+
"""
104+
num_entries = len(self.event_id)
105+
size(exp_len=num_entries, var=self.member, var_name="Forecast.member")
106+
size(exp_len=num_entries, var=self.lead_time, var_name="Forecast.lead_time")

climada/hazard/io.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -917,6 +917,10 @@ def write_hdf5(self, file_name, todense=False):
917917
# Centroids have their own write_hdf5 method,
918918
# which is invoked at the end of this method (s.b.)
919919
continue
920+
elif var_name == "lead_time":
921+
hf_data.create_dataset(
922+
var_name, data=var_val.astype("timedelta64[ns]").astype("int64")
923+
)
920924
elif isinstance(var_val, sparse.csr_matrix):
921925
if todense:
922926
hf_data.create_dataset(var_name, data=var_val.toarray())
@@ -987,7 +991,11 @@ def from_hdf5(cls, file_name):
987991
continue
988992
if var_name == "centroids":
989993
continue
990-
if isinstance(var_val, np.ndarray) and var_val.ndim == 1:
994+
if var_name == "lead_time":
995+
hazard_kwargs[var_name] = np.array(hf_data.get(var_name)).astype(
996+
"timedelta64[ns]"
997+
)
998+
elif isinstance(var_val, np.ndarray) and var_val.ndim == 1:
991999
hazard_kwargs[var_name] = np.array(hf_data.get(var_name))
9921000
elif isinstance(var_val, sparse.csr_matrix):
9931001
hf_csr = hf_data.get(var_name)

climada/hazard/test/test_forecast.py

Lines changed: 67 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -41,13 +41,13 @@ def hazard(haz_kwargs):
4141

4242

4343
@pytest.fixture
44-
def lead_time():
45-
return pd.timedelta_range("1h", periods=6).to_numpy()
44+
def lead_time(haz_kwargs):
45+
return pd.timedelta_range("1h", periods=len(haz_kwargs["event_id"])).to_numpy()
4646

4747

4848
@pytest.fixture
49-
def member():
50-
return np.arange(6)
49+
def member(haz_kwargs):
50+
return np.arange(len(haz_kwargs["event_id"]))
5151

5252

5353
@pytest.fixture
@@ -78,6 +78,13 @@ def test_init_hazard_forecast(haz_fc, member, lead_time, haz_kwargs):
7878
assert_hazard_kwargs(haz_fc, **haz_kwargs)
7979

8080

81+
def test_init_hazard_forecast_error(hazard, member, lead_time, haz_kwargs):
82+
with pytest.raises(ValueError, match="Forecast.lead_time"):
83+
HazardForecast(lead_time=lead_time[:-2], member=member, **haz_kwargs)
84+
with pytest.raises(ValueError, match="Forecast.member"):
85+
HazardForecast.from_hazard(hazard, lead_time=lead_time, member=member[1:])
86+
87+
8188
def test_from_hazard(lead_time, member, hazard, haz_kwargs):
8289
haz_fc_from_haz = HazardForecast.from_hazard(
8390
hazard, lead_time=lead_time, member=member
@@ -100,10 +107,60 @@ def test_hazard_forecast_concat(haz_fc, lead_time, member):
100107
npt.assert_array_equal(haz_fc_concat.member, np.concatenate([member, member]))
101108

102109

103-
def test_hazard_forecast_select(haz_fc, lead_time, member):
110+
@pytest.mark.parametrize(
111+
"var, var_select",
112+
[("event_id", "event_id"), ("event_name", "event_names"), ("date", "date")],
113+
)
114+
def test_hazard_forecast_select(haz_fc, lead_time, member, haz_kwargs, var, var_select):
104115
"""Check if Hazard.select works on the derived class"""
105-
haz_fc_select = haz_fc.select(event_id=[4, 1])
106-
# NOTE: Events keep their original order
107-
npt.assert_array_equal(haz_fc_select.event_id, haz_fc.event_id[np.array([3, 0])])
108-
npt.assert_array_equal(haz_fc_select.member, member[np.array([3, 0])])
109-
npt.assert_array_equal(haz_fc_select.lead_time, lead_time[np.array([3, 0])])
116+
117+
select_mask = np.array([3, 2])
118+
ordered_select_mask = np.array([3, 2])
119+
if var == "date":
120+
# Date needs to be a valid delta
121+
select_mask = np.array([2, 3])
122+
ordered_select_mask = np.array([2, 3])
123+
124+
var_value = np.array(haz_kwargs[var])[select_mask]
125+
# event_name is a list, convert to numpy array for indexing
126+
haz_fc_sel = haz_fc.select(**{var_select: var_value})
127+
# Note: order is preserved
128+
npt.assert_array_equal(
129+
haz_fc_sel.event_id,
130+
haz_fc.event_id[ordered_select_mask],
131+
)
132+
npt.assert_array_equal(
133+
haz_fc_sel.event_name,
134+
np.array(haz_fc.event_name)[ordered_select_mask],
135+
)
136+
npt.assert_array_equal(haz_fc_sel.date, haz_fc.date[ordered_select_mask])
137+
npt.assert_array_equal(haz_fc_sel.frequency, haz_fc.frequency[ordered_select_mask])
138+
npt.assert_array_equal(haz_fc_sel.member, member[ordered_select_mask])
139+
npt.assert_array_equal(haz_fc_sel.lead_time, lead_time[ordered_select_mask])
140+
npt.assert_array_equal(
141+
haz_fc_sel.intensity.todense(),
142+
haz_fc.intensity.todense()[ordered_select_mask],
143+
)
144+
npt.assert_array_equal(
145+
haz_fc_sel.fraction.todense(),
146+
haz_fc.fraction.todense()[ordered_select_mask],
147+
)
148+
149+
assert haz_fc_sel.centroids == haz_fc.centroids
150+
151+
152+
def test_write_read_hazard_forecast(haz_fc, tmp_path):
153+
154+
file_name = tmp_path / "test_hazard_forecast.h5"
155+
156+
haz_fc.write_hdf5(file_name)
157+
haz_fc_read = HazardForecast.from_hdf5(file_name)
158+
159+
assert haz_fc_read.lead_time.dtype.kind == np.dtype("timedelta64").kind
160+
161+
for key in haz_fc.__dict__.keys():
162+
if key in ["intensity", "fraction"]:
163+
(haz_fc.__dict__[key] != haz_fc_read.__dict__[key]).nnz == 0
164+
else:
165+
# npt.assert_array_equal also works for comparing int, float or list
166+
npt.assert_array_equal(haz_fc.__dict__[key], haz_fc_read.__dict__[key])

0 commit comments

Comments
 (0)