Skip to content

Commit 928751d

Browse files
committed
Add HazardForecast.concat
1 parent 64da484 commit 928751d

File tree

3 files changed

+48
-12
lines changed

3 files changed

+48
-12
lines changed

climada/hazard/forecast.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
"""
2121

2222
import logging
23+
from typing import Self
2324

2425
import numpy as np
2526

@@ -104,3 +105,13 @@ def _check_sizes(self):
104105
num_entries = len(self.event_id)
105106
size(exp_len=num_entries, var=self.member, var_name="Forecast.member")
106107
size(exp_len=num_entries, var=self.lead_time, var_name="Forecast.lead_time")
108+
109+
@classmethod
110+
def concat(cls, haz_list: list[Self]) -> Self:
111+
"""Concatenate multiple HazardForecast instances and return a new object"""
112+
if len(haz_list) == 0:
113+
return cls()
114+
hazard = Hazard.concat(haz_list)
115+
lead_time = np.concatenate(tuple(haz.lead_time for haz in haz_list))
116+
member = np.concatenate(tuple(haz.member for haz in haz_list))
117+
return cls.from_hazard(hazard, lead_time=lead_time, member=member)

climada/hazard/test/test_forecast.py

Lines changed: 31 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -95,16 +95,37 @@ def test_from_hazard(lead_time, member, hazard, haz_kwargs):
9595
assert_hazard_kwargs(haz_fc_from_haz, **haz_kwargs)
9696

9797

98-
@pytest.mark.skip("Concat from base class does not work")
99-
def test_hazard_forecast_concat(haz_fc, lead_time, member):
100-
haz_fc1 = haz_fc.select(event_id=[1, 2])
101-
haz_fc2 = haz_fc.select(event_id=[3, 4])
102-
haz_fc_concat = HazardForecast.concat([haz_fc1, haz_fc2])
103-
assert isinstance(haz_fc_concat, HazardForecast)
104-
npt.assert_array_equal(
105-
haz_fc_concat.lead_time, np.concatenate([lead_time, lead_time])
106-
)
107-
npt.assert_array_equal(haz_fc_concat.member, np.concatenate([member, member]))
98+
class TestHazardForecastConcat:
99+
100+
def test_concat(self, haz_fc, lead_time, member, haz_kwargs):
101+
haz_fc1 = haz_fc.select(event_id=[3])
102+
haz_fc2 = HazardForecast(
103+
haz_type=haz_kwargs["haz_type"], frequency_unit=haz_kwargs["frequency_unit"]
104+
) # Empty hazard
105+
haz_fc3 = haz_fc.select(event_id=[1, 2])
106+
haz_fc_concat = HazardForecast.concat([haz_fc1, haz_fc2, haz_fc3])
107+
assert isinstance(haz_fc_concat, HazardForecast)
108+
assert haz_fc_concat.size == 3
109+
npt.assert_array_equal(
110+
haz_fc_concat.lead_time, np.concatenate((lead_time[2:3], lead_time[0:2]))
111+
)
112+
npt.assert_array_equal(
113+
haz_fc_concat.member, np.concatenate((member[2:3], member[0:2]))
114+
)
115+
npt.assert_array_equal(haz_fc_concat.event_id, [3, 1, 2])
116+
117+
def test_empty_list(self):
118+
haz_concat = HazardForecast.concat([])
119+
assert isinstance(haz_concat, HazardForecast)
120+
assert haz_concat.size == 0
121+
npt.assert_array_equal(haz_concat.lead_time, [])
122+
npt.assert_array_equal(haz_concat.event_id, [])
123+
124+
def test_type_fail(self, haz_fc, hazard):
125+
with pytest.raises(TypeError, match="different classes"):
126+
HazardForecast.concat([haz_fc, hazard])
127+
with pytest.raises(TypeError, match="different classes"):
128+
Hazard.concat([haz_fc, hazard])
108129

109130

110131
@pytest.mark.parametrize(

climada/util/forecast.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,11 @@ def __init__(
5252
"""
5353

5454
self.lead_time = (
55-
np.asarray(lead_time) if lead_time is not None else np.array([])
55+
np.asarray(lead_time)
56+
if lead_time is not None
57+
else np.array([], dtype="timedelta64[ns]")
58+
)
59+
self.member = (
60+
np.asarray(member) if member is not None else np.array([], dtype="int")
5661
)
57-
self.member = np.asarray(member) if member is not None else np.array([])
5862
super().__init__(**kwargs)

0 commit comments

Comments
 (0)