Skip to content

Commit 52edc45

Browse files
authored
Add HazardForecast.concat (#1184)
* Add HazardForecast.concat
1 parent 302be63 commit 52edc45

File tree

3 files changed

+47
-12
lines changed

3 files changed

+47
-12
lines changed

climada/hazard/forecast.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,16 @@ def _check_sizes(self):
105105
size(exp_len=num_entries, var=self.member, var_name="Forecast.member")
106106
size(exp_len=num_entries, var=self.lead_time, var_name="Forecast.lead_time")
107107

108+
@classmethod
109+
def concat(cls, haz_list: list):
110+
"""Concatenate multiple HazardForecast instances and return a new object"""
111+
if len(haz_list) == 0:
112+
return cls()
113+
hazard = Hazard.concat(haz_list)
114+
lead_time = np.concatenate(tuple(haz.lead_time for haz in haz_list))
115+
member = np.concatenate(tuple(haz.member for haz in haz_list))
116+
return cls.from_hazard(hazard, lead_time=lead_time, member=member)
117+
108118
def select(
109119
self,
110120
member=None,

climada/hazard/test/test_forecast.py

Lines changed: 31 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -95,16 +95,37 @@ def test_from_hazard(lead_time, member, hazard, haz_kwargs):
9595
assert_hazard_kwargs(haz_fc_from_haz, **haz_kwargs)
9696

9797

98-
@pytest.mark.skip("Concat from base class does not work")
99-
def test_hazard_forecast_concat(haz_fc, lead_time, member):
100-
haz_fc1 = haz_fc.select(event_id=[1, 2])
101-
haz_fc2 = haz_fc.select(event_id=[3, 4])
102-
haz_fc_concat = HazardForecast.concat([haz_fc1, haz_fc2])
103-
assert isinstance(haz_fc_concat, HazardForecast)
104-
npt.assert_array_equal(
105-
haz_fc_concat.lead_time, np.concatenate([lead_time, lead_time])
106-
)
107-
npt.assert_array_equal(haz_fc_concat.member, np.concatenate([member, member]))
98+
class TestHazardForecastConcat:
99+
100+
def test_concat(self, haz_fc, lead_time, member, haz_kwargs):
101+
haz_fc1 = haz_fc.select(event_id=[3])
102+
haz_fc2 = HazardForecast(
103+
haz_type=haz_kwargs["haz_type"], frequency_unit=haz_kwargs["frequency_unit"]
104+
) # Empty hazard
105+
haz_fc3 = haz_fc.select(event_id=[1, 2])
106+
haz_fc_concat = HazardForecast.concat([haz_fc1, haz_fc2, haz_fc3])
107+
assert isinstance(haz_fc_concat, HazardForecast)
108+
assert haz_fc_concat.size == 3
109+
npt.assert_array_equal(
110+
haz_fc_concat.lead_time, np.concatenate((lead_time[2:3], lead_time[0:2]))
111+
)
112+
npt.assert_array_equal(
113+
haz_fc_concat.member, np.concatenate((member[2:3], member[0:2]))
114+
)
115+
npt.assert_array_equal(haz_fc_concat.event_id, [3, 1, 2])
116+
117+
def test_empty_list(self):
118+
haz_concat = HazardForecast.concat([])
119+
assert isinstance(haz_concat, HazardForecast)
120+
assert haz_concat.size == 0
121+
npt.assert_array_equal(haz_concat.lead_time, [])
122+
npt.assert_array_equal(haz_concat.event_id, [])
123+
124+
def test_type_fail(self, haz_fc, hazard):
125+
with pytest.raises(TypeError, match="different classes"):
126+
HazardForecast.concat([haz_fc, hazard])
127+
with pytest.raises(TypeError, match="different classes"):
128+
Hazard.concat([haz_fc, hazard])
108129

109130

110131
class TestSelect:

climada/util/forecast.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -52,9 +52,13 @@ def __init__(
5252
"""
5353

5454
self.lead_time = (
55-
np.asarray(lead_time) if lead_time is not None else np.array([])
55+
np.asarray(lead_time)
56+
if lead_time is not None
57+
else np.array([], dtype="timedelta64[ns]")
58+
)
59+
self.member = (
60+
np.asarray(member) if member is not None else np.array([], dtype="int")
5661
)
57-
self.member = np.asarray(member) if member is not None else np.array([])
5862
super().__init__(**kwargs)
5963

6064
def idx_member(self, member: np.ndarray) -> np.ndarray:

0 commit comments

Comments
 (0)