Skip to content

Commit 4b7d4bc

Browse files
author
Chahan Kropf
committed
Merge remote-tracking branch 'origin/forecast-class' into forecast/select_extended_tests
2 parents dd78fac + 48b6d40 commit 4b7d4bc

File tree

5 files changed

+81
-12
lines changed

5 files changed

+81
-12
lines changed

climada/engine/impact_forecast.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
import numpy as np
2525

2626
from ..util import log_level
27+
from ..util.checker import size
2728
from ..util.forecast import Forecast
2829
from .impact import Impact
2930

@@ -51,8 +52,8 @@ def __init__(
5152
impact_kwargs
5253
Keyword-arguments passed to ~:py:class`climada.engine.impact.Impact`.
5354
"""
54-
# TODO: Maybe assert array lengths?
5555
super().__init__(lead_time=lead_time, member=member, **impact_kwargs)
56+
self._check_sizes()
5657

5758
@classmethod
5859
def from_impact(
@@ -88,3 +89,16 @@ def from_impact(
8889
imp_mat=impact.imp_mat,
8990
haz_type=impact.haz_type,
9091
)
92+
93+
def _check_sizes(self):
94+
"""Check sizes of forecast data vs. impact data.
95+
96+
Raises
97+
------
98+
ValueError
99+
If the sizes of the forecast data do not match the
100+
:py:attr:`~climada.engine.impact.Impact.event_id`
101+
"""
102+
num_entries = len(self.event_id)
103+
size(exp_len=num_entries, var=self.member, var_name="Forecast.member")
104+
size(exp_len=num_entries, var=self.lead_time, var_name="Forecast.lead_time")

climada/engine/test/test_impact_forecast.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -41,13 +41,15 @@ def impact(impact_kwargs):
4141

4242

4343
@pytest.fixture
44-
def lead_time():
45-
return pd.timedelta_range(start="1 day", periods=6).to_numpy()
44+
def lead_time(impact_kwargs):
45+
return pd.timedelta_range(
46+
start="1 day", periods=len(impact_kwargs["event_id"])
47+
).to_numpy()
4648

4749

4850
@pytest.fixture
49-
def member():
50-
return np.arange(6)
51+
def member(impact_kwargs):
52+
return np.arange(len(impact_kwargs["event_id"]))
5153

5254

5355
@pytest.fixture
@@ -76,6 +78,12 @@ def test_impact_forecast_init(self, impact_kwargs, lead_time, member):
7678
npt.assert_array_equal(forecast1.member, member)
7779
self.assert_impact_kwargs(forecast1, **impact_kwargs)
7880

81+
def test_impact_forecast_init_error(self, impact, impact_kwargs, lead_time, member):
82+
with pytest.raises(ValueError, match="Forecast.lead_time"):
83+
ImpactForecast(lead_time=lead_time[:-2], member=member, **impact_kwargs)
84+
with pytest.raises(ValueError, match="Forecast.member"):
85+
ImpactForecast.from_impact(impact, lead_time=lead_time, member=member[1:])
86+
7987
def test_impact_forecast_from_impact(
8088
self, impact_forecast, impact_kwargs, lead_time, member
8189
):

climada/hazard/forecast.py

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,9 @@
2323

2424
import numpy as np
2525

26-
from climada.hazard.base import Hazard
27-
from climada.util.forecast import Forecast
26+
from ..util.checker import size
27+
from ..util.forecast import Forecast
28+
from .base import Hazard
2829

2930
LOGGER = logging.getLogger(__name__)
3031

@@ -52,6 +53,7 @@ def __init__(
5253
py:meth`~climada.hazard.base.Hazard.__init__` for details.
5354
"""
5455
super().__init__(lead_time=lead_time, member=member, **hazard_kwargs)
56+
self._check_sizes()
5557

5658
@classmethod
5759
def from_hazard(cls, hazard: Hazard, lead_time: np.ndarray, member: np.ndarray):
@@ -89,3 +91,16 @@ def from_hazard(cls, hazard: Hazard, lead_time: np.ndarray, member: np.ndarray):
8991
intensity=hazard.intensity,
9092
fraction=hazard.fraction,
9193
)
94+
95+
def _check_sizes(self):
96+
"""Check sizes of forecast data vs. hazard data.
97+
98+
Raises
99+
------
100+
ValueError
101+
If the sizes of the forecast data do not match the
102+
:py:attr:`~climada.hazard.base.Hazard.event_id`
103+
"""
104+
num_entries = len(self.event_id)
105+
size(exp_len=num_entries, var=self.member, var_name="Forecast.member")
106+
size(exp_len=num_entries, var=self.lead_time, var_name="Forecast.lead_time")

climada/hazard/io.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -917,6 +917,10 @@ def write_hdf5(self, file_name, todense=False):
917917
# Centroids have their own write_hdf5 method,
918918
# which is invoked at the end of this method (s.b.)
919919
continue
920+
elif var_name == "lead_time":
921+
hf_data.create_dataset(
922+
var_name, data=var_val.astype("timedelta64[ns]").astype("int64")
923+
)
920924
elif isinstance(var_val, sparse.csr_matrix):
921925
if todense:
922926
hf_data.create_dataset(var_name, data=var_val.toarray())
@@ -987,7 +991,11 @@ def from_hdf5(cls, file_name):
987991
continue
988992
if var_name == "centroids":
989993
continue
990-
if isinstance(var_val, np.ndarray) and var_val.ndim == 1:
994+
if var_name == "lead_time":
995+
hazard_kwargs[var_name] = np.array(hf_data.get(var_name)).astype(
996+
"timedelta64[ns]"
997+
)
998+
elif isinstance(var_val, np.ndarray) and var_val.ndim == 1:
991999
hazard_kwargs[var_name] = np.array(hf_data.get(var_name))
9921000
elif isinstance(var_val, sparse.csr_matrix):
9931001
hf_csr = hf_data.get(var_name)

climada/hazard/test/test_forecast.py

Lines changed: 28 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -41,13 +41,13 @@ def hazard(haz_kwargs):
4141

4242

4343
@pytest.fixture
44-
def lead_time():
45-
return pd.timedelta_range("1h", periods=6).to_numpy()
44+
def lead_time(haz_kwargs):
45+
return pd.timedelta_range("1h", periods=len(haz_kwargs["event_id"])).to_numpy()
4646

4747

4848
@pytest.fixture
49-
def member():
50-
return np.arange(6)
49+
def member(haz_kwargs):
50+
return np.arange(len(haz_kwargs["event_id"]))
5151

5252

5353
@pytest.fixture
@@ -78,6 +78,13 @@ def test_init_hazard_forecast(haz_fc, member, lead_time, haz_kwargs):
7878
assert_hazard_kwargs(haz_fc, **haz_kwargs)
7979

8080

81+
def test_init_hazard_forecast_error(hazard, member, lead_time, haz_kwargs):
82+
with pytest.raises(ValueError, match="Forecast.lead_time"):
83+
HazardForecast(lead_time=lead_time[:-2], member=member, **haz_kwargs)
84+
with pytest.raises(ValueError, match="Forecast.member"):
85+
HazardForecast.from_hazard(hazard, lead_time=lead_time, member=member[1:])
86+
87+
8188
def test_from_hazard(lead_time, member, hazard, haz_kwargs):
8289
haz_fc_from_haz = HazardForecast.from_hazard(
8390
hazard, lead_time=lead_time, member=member
@@ -140,3 +147,20 @@ def test_hazard_forecast_select(haz_fc, lead_time, member, haz_kwargs, var, var_
140147
)
141148

142149
assert haz_fc_sel.centroids == haz_fc.centroids
150+
151+
152+
def test_write_read_hazard_forecast(haz_fc, tmp_path):
153+
154+
file_name = tmp_path / "test_hazard_forecast.h5"
155+
156+
haz_fc.write_hdf5(file_name)
157+
haz_fc_read = HazardForecast.from_hdf5(file_name)
158+
159+
assert haz_fc_read.lead_time.dtype.kind == np.dtype("timedelta64").kind
160+
161+
for key in haz_fc.__dict__.keys():
162+
if key in ["intensity", "fraction"]:
163+
(haz_fc.__dict__[key] != haz_fc_read.__dict__[key]).nnz == 0
164+
else:
165+
# npt.assert_array_equal also works for comparing int, float or list
166+
npt.assert_array_equal(haz_fc.__dict__[key], haz_fc_read.__dict__[key])

0 commit comments

Comments
 (0)