Skip to content

Commit d571bb7

Browse files
committed
Merge branch 'forecast-class' into impactCalc_block_nonsense_attrs
2 parents 4b5ae95 + a386d22 commit d571bb7

File tree

5 files changed

+91
-41
lines changed

5 files changed

+91
-41
lines changed

climada/engine/impact_forecast.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -45,9 +45,9 @@ def __init__(
4545
Parameters
4646
----------
4747
lead_time : np.ndarray, optional
48-
The lead time associated with each event entry
48+
The lead time associated with each event entry, given as timedelta64 type
4949
member : np.ndarray, optional
50-
The ensemble member associated with each event entry
50+
The ensemble member associated with each event entry, given as integers
5151
impact_kwargs
5252
Keyword-arguments passed to ~:py:class`climada.engine.impact.Impact`.
5353
"""
@@ -65,9 +65,9 @@ def from_impact(
6565
impact : climada.engine.impact.Impact
6666
The impact object whose data to use in the forecast object
6767
lead_time : np.ndarray, optional
68-
The lead time associated with each event entry
68+
The lead time associated with each event entry, given as timedelta64 type
6969
member : np.ndarray, optional
70-
The ensemble member associated with each event entry
70+
The ensemble member associated with each event entry, given as integers
7171
"""
7272
with log_level("WARNING", "climada.engine.impact"):
7373
return cls(

climada/engine/test/test_impact_forecast.py

Lines changed: 57 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -40,35 +40,66 @@ def impact(impact_kwargs):
4040
return Impact(**impact_kwargs)
4141

4242

43-
def assert_impact_kwargs(impact: Impact, **kwargs):
44-
for key, value in kwargs.items():
45-
attr = getattr(impact, key)
46-
if isinstance(value, (np.ndarray, list)):
47-
npt.assert_array_equal(attr, value)
48-
elif isinstance(value, csr_matrix):
49-
npt.assert_array_equal(attr.todense(), value.todense())
50-
else:
51-
assert attr == value
43+
@pytest.fixture
44+
def lead_time():
45+
return pd.timedelta_range(start="1 day", periods=6).to_numpy()
5246

5347

54-
class TestImpactForecastInit:
55-
lead_time = pd.date_range("2000-01-01", "2000-01-02", periods=6).to_numpy()
56-
member = np.arange(6)
48+
@pytest.fixture
49+
def member():
50+
return np.arange(6)
51+
52+
53+
@pytest.fixture
54+
def impact_forecast(impact, lead_time, member):
55+
return ImpactForecast.from_impact(impact, lead_time=lead_time, member=member)
56+
5757

58-
def test_impact_forecast_init(self, impact_kwargs):
58+
class TestImpactForecastInit:
59+
def assert_impact_kwargs(self, impact: Impact, **kwargs):
60+
for key, value in kwargs.items():
61+
attr = getattr(impact, key)
62+
if isinstance(value, (np.ndarray, list)):
63+
npt.assert_array_equal(attr, value)
64+
elif isinstance(value, csr_matrix):
65+
npt.assert_array_equal(attr.todense(), value.todense())
66+
else:
67+
assert attr == value
68+
69+
def test_impact_forecast_init(self, impact_kwargs, lead_time, member):
5970
forecast1 = ImpactForecast(
60-
lead_time=self.lead_time,
61-
member=self.member,
71+
lead_time=lead_time,
72+
member=member,
6273
**impact_kwargs,
6374
)
64-
npt.assert_array_equal(forecast1.lead_time, self.lead_time)
65-
npt.assert_array_equal(forecast1.member, self.member)
66-
assert_impact_kwargs(forecast1, **impact_kwargs)
67-
68-
def test_impact_forecast_from_impact(self, impact, impact_kwargs):
69-
forecast = ImpactForecast.from_impact(
70-
impact, lead_time=self.lead_time, member=self.member
71-
)
72-
npt.assert_array_equal(forecast.lead_time, self.lead_time)
73-
npt.assert_array_equal(forecast.member, self.member)
74-
assert_impact_kwargs(forecast, **impact_kwargs)
75+
npt.assert_array_equal(forecast1.lead_time, lead_time)
76+
npt.assert_array_equal(forecast1.member, member)
77+
self.assert_impact_kwargs(forecast1, **impact_kwargs)
78+
79+
def test_impact_forecast_from_impact(
80+
self, impact_forecast, impact_kwargs, lead_time, member
81+
):
82+
npt.assert_array_equal(impact_forecast.lead_time, lead_time)
83+
npt.assert_array_equal(impact_forecast.member, member)
84+
self.assert_impact_kwargs(impact_forecast, **impact_kwargs)
85+
86+
87+
def test_impact_forecast_select(impact_forecast, lead_time, member, impact_kwargs):
88+
"""Check if Impact.select works on the derived class"""
89+
event_ids = impact_kwargs["event_id"][np.array([2, 0])]
90+
impact_fc = impact_forecast.select(event_ids=event_ids)
91+
# NOTE: Events keep their original order
92+
npt.assert_array_equal(
93+
impact_fc.event_id, impact_forecast.event_id[np.array([0, 2])]
94+
)
95+
npt.assert_array_equal(impact_fc.member, member[np.array([0, 2])])
96+
npt.assert_array_equal(impact_fc.lead_time, lead_time[np.array([0, 2])])
97+
98+
99+
@pytest.mark.skip("Concat from base class does not work")
100+
def test_impact_forecast_concat(impact_forecast, member):
101+
"""Check if Impact.concat works on the derived class"""
102+
impact_fc = ImpactForecast.concat(
103+
[impact_forecast, impact_forecast], reset_event_ids=True
104+
)
105+
npt.assert_array_equal(impact_fc.member, np.concatenate([member, member]))

climada/hazard/test/test_forecast.py

Lines changed: 21 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,8 +29,6 @@
2929
from climada.hazard.forecast import HazardForecast
3030
from climada.hazard.test.test_base import hazard_kwargs
3131

32-
# --- Examples for fixtures and test organization --- #
33-
3432

3533
@pytest.fixture
3634
def haz_kwargs():
@@ -88,3 +86,24 @@ def test_from_hazard(lead_time, member, hazard, haz_kwargs):
8886
npt.assert_array_equal(haz_fc_from_haz.lead_time, lead_time)
8987
npt.assert_array_equal(haz_fc_from_haz.member, member)
9088
assert_hazard_kwargs(haz_fc_from_haz, **haz_kwargs)
89+
90+
91+
@pytest.mark.skip("Concat from base class does not work")
92+
def test_hazard_forecast_concat(haz_fc, lead_time, member):
93+
haz_fc1 = haz_fc.select(event_id=[1, 2])
94+
haz_fc2 = haz_fc.select(event_id=[3, 4])
95+
haz_fc_concat = HazardForecast.concat([haz_fc1, haz_fc2])
96+
assert isinstance(haz_fc_concat, HazardForecast)
97+
npt.assert_array_equal(
98+
haz_fc_concat.lead_time, np.concatenate([lead_time, lead_time])
99+
)
100+
npt.assert_array_equal(haz_fc_concat.member, np.concatenate([member, member]))
101+
102+
103+
def test_hazard_forecast_select(haz_fc, lead_time, member):
104+
"""Check if Hazard.select works on the derived class"""
105+
haz_fc_select = haz_fc.select(event_id=[4, 1])
106+
# NOTE: Events keep their original order
107+
npt.assert_array_equal(haz_fc_select.event_id, haz_fc.event_id[np.array([3, 0])])
108+
npt.assert_array_equal(haz_fc_select.member, member[np.array([3, 0])])
109+
npt.assert_array_equal(haz_fc_select.lead_time, lead_time[np.array([3, 0])])

climada/util/forecast.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,8 +28,8 @@ class Forecast:
2828
Attributes
2929
----------
3030
lead_time : np.ndarray
31-
Array of forecast lead times, given as datetime64 objects.
32-
Represents the time points for which forecasts are made.
31+
Array of forecast lead times, given as timedelta64 objects.
32+
Represents the lead times of the forecasts.
3333
member : np.ndarray
3434
Array of ensemble member identifiers, given as integers.
3535
Represents different forecast ensemble members.

climada/util/test/test_forecast.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121

2222
import numpy as np
2323
import numpy.testing as npt
24+
import pandas as pd
2425

2526
from climada.util.forecast import Forecast
2627

@@ -34,19 +35,18 @@ def test_forecast_init():
3435
forecast = Forecast(member=np.array([1, 2]))
3536
npt.assert_array_equal(forecast.member, np.array([1, 2]), strict=True)
3637

37-
forecast = Forecast(lead_time=np.array([1, 2]))
38-
npt.assert_array_equal(forecast.lead_time, np.array([1, 2]), strict=True)
38+
forecast = Forecast(lead_time=np.array([6, 12], dtype="timedelta64[h]"))
39+
npt.assert_array_equal(
40+
forecast.lead_time, np.array([6, 12], dtype="timedelta64[h]"), strict=True
41+
)
3942

4043
forecast = Forecast(lead_time=np.array([1, 2]), member=[3, 4])
4144
npt.assert_array_equal(forecast.lead_time, np.array([1, 2]), strict=True)
4245
npt.assert_array_equal(forecast.member, np.array([3, 4]), strict=True)
4346
assert isinstance(forecast.member, np.ndarray)
4447

4548
# Test with datetime64 including seconds
46-
lead_times_seconds = np.array(
47-
["2024-01-01T00:00:00", "2024-01-01T00:01:00", "2024-01-01"],
48-
dtype="datetime64[s]",
49-
)
49+
lead_times_seconds = pd.timedelta_range(start="1 day", periods=4).to_numpy()
5050
forecast = Forecast(lead_time=lead_times_seconds, member=[1, 2, 3])
5151
npt.assert_array_equal(forecast.lead_time, lead_times_seconds, strict=True)
52-
assert forecast.lead_time.dtype == np.dtype("datetime64[s]")
52+
assert forecast.lead_time.dtype == np.dtype("timedelta64[ns]")

0 commit comments

Comments
 (0)