Skip to content

Commit 302be63

Browse files
ValentinGebhartelianekoblerpeanutfun
authored
Add select method to HazardForecast (#1185)
* add HazardForecast.select and extent test --------- Co-authored-by: Eliane Kobler <[email protected]> Co-authored-by: Lukas Riedel <[email protected]>
1 parent b17bcc9 commit 302be63

File tree

2 files changed

+148
-39
lines changed

2 files changed

+148
-39
lines changed

climada/hazard/forecast.py

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -104,3 +104,64 @@ def _check_sizes(self):
104104
num_entries = len(self.event_id)
105105
size(exp_len=num_entries, var=self.member, var_name="Forecast.member")
106106
size(exp_len=num_entries, var=self.lead_time, var_name="Forecast.lead_time")
107+
108+
def select(
109+
self,
110+
member=None,
111+
lead_time=None,
112+
event_names=None,
113+
event_id=None,
114+
date=None,
115+
orig=None,
116+
reg_id=None,
117+
extent=None,
118+
reset_frequency=False,
119+
):
120+
"""Select entries based on the parameters and return a new instance.
121+
122+
The selection will contain the intersection of all given parameters.
123+
124+
Parameters
125+
----------
126+
member : Sequence of ints
127+
Ensemble members to select
128+
lead_time : Sequence of numpy.timedelta64
129+
Lead times to select
130+
131+
Returns
132+
-------
133+
HazardForecast
134+
135+
See Also
136+
--------
137+
:py:meth:`~climada.hazard.base.Hazard.select`
138+
"""
139+
if member is not None or lead_time is not None:
140+
mask_member = (
141+
self.idx_member(member)
142+
if member is not None
143+
else np.full_like(self.member, True, dtype=bool)
144+
)
145+
mask_lead_time = (
146+
self.idx_lead_time(lead_time)
147+
if lead_time is not None
148+
else np.full_like(self.lead_time, True, dtype=bool)
149+
)
150+
event_id_from_forecast_mask = np.asarray(self.event_id)[
151+
(mask_member & mask_lead_time)
152+
]
153+
event_id = (
154+
np.intersect1d(event_id, event_id_from_forecast_mask)
155+
if event_id is not None
156+
else event_id_from_forecast_mask
157+
)
158+
159+
return super().select(
160+
event_names=event_names,
161+
event_id=event_id,
162+
date=date,
163+
orig=orig,
164+
reg_id=reg_id,
165+
extent=extent,
166+
reset_frequency=reset_frequency,
167+
)

climada/hazard/test/test_forecast.py

Lines changed: 87 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -107,46 +107,94 @@ def test_hazard_forecast_concat(haz_fc, lead_time, member):
107107
npt.assert_array_equal(haz_fc_concat.member, np.concatenate([member, member]))
108108

109109

110-
@pytest.mark.parametrize(
111-
"var, var_select",
112-
[("event_id", "event_id"), ("event_name", "event_names"), ("date", "date")],
113-
)
114-
def test_hazard_forecast_select(haz_fc, lead_time, member, haz_kwargs, var, var_select):
115-
"""Check if Hazard.select works on the derived class"""
116-
117-
select_mask = np.array([3, 2])
118-
ordered_select_mask = np.array([3, 2])
119-
if var == "date":
120-
# Date needs to be a valid delta
121-
select_mask = np.array([2, 3])
122-
ordered_select_mask = np.array([2, 3])
123-
124-
var_value = np.array(haz_kwargs[var])[select_mask]
125-
# event_name is a list, convert to numpy array for indexing
126-
haz_fc_sel = haz_fc.select(**{var_select: var_value})
127-
# Note: order is preserved
128-
npt.assert_array_equal(
129-
haz_fc_sel.event_id,
130-
haz_fc.event_id[ordered_select_mask],
131-
)
132-
npt.assert_array_equal(
133-
haz_fc_sel.event_name,
134-
np.array(haz_fc.event_name)[ordered_select_mask],
135-
)
136-
npt.assert_array_equal(haz_fc_sel.date, haz_fc.date[ordered_select_mask])
137-
npt.assert_array_equal(haz_fc_sel.frequency, haz_fc.frequency[ordered_select_mask])
138-
npt.assert_array_equal(haz_fc_sel.member, member[ordered_select_mask])
139-
npt.assert_array_equal(haz_fc_sel.lead_time, lead_time[ordered_select_mask])
140-
npt.assert_array_equal(
141-
haz_fc_sel.intensity.todense(),
142-
haz_fc.intensity.todense()[ordered_select_mask],
143-
)
144-
npt.assert_array_equal(
145-
haz_fc_sel.fraction.todense(),
146-
haz_fc.fraction.todense()[ordered_select_mask],
147-
)
110+
class TestSelect:
148111

149-
assert haz_fc_sel.centroids == haz_fc.centroids
112+
@pytest.mark.parametrize(
113+
"var, var_select",
114+
[("event_id", "event_id"), ("event_name", "event_names"), ("date", "date")],
115+
)
116+
def test_base_class_select(
117+
self, haz_fc, lead_time, member, haz_kwargs, var, var_select
118+
):
119+
"""Check if Hazard.select works on the derived class"""
120+
121+
select_mask = np.array([3, 2])
122+
ordered_select_mask = np.array([3, 2])
123+
if var == "date":
124+
# Date needs to be a valid delta
125+
select_mask = np.array([2, 3])
126+
ordered_select_mask = np.array([2, 3])
127+
128+
var_value = np.array(haz_kwargs[var])[select_mask]
129+
# event_name is a list, convert to numpy array for indexing
130+
haz_fc_sel = haz_fc.select(**{var_select: var_value})
131+
# Note: order is preserved
132+
npt.assert_array_equal(
133+
haz_fc_sel.event_id,
134+
haz_fc.event_id[ordered_select_mask],
135+
)
136+
npt.assert_array_equal(
137+
haz_fc_sel.event_name,
138+
np.array(haz_fc.event_name)[ordered_select_mask],
139+
)
140+
npt.assert_array_equal(haz_fc_sel.date, haz_fc.date[ordered_select_mask])
141+
npt.assert_array_equal(
142+
haz_fc_sel.frequency, haz_fc.frequency[ordered_select_mask]
143+
)
144+
npt.assert_array_equal(haz_fc_sel.member, member[ordered_select_mask])
145+
npt.assert_array_equal(haz_fc_sel.lead_time, lead_time[ordered_select_mask])
146+
npt.assert_array_equal(
147+
haz_fc_sel.intensity.todense(),
148+
haz_fc.intensity.todense()[ordered_select_mask],
149+
)
150+
npt.assert_array_equal(
151+
haz_fc_sel.fraction.todense(),
152+
haz_fc.fraction.todense()[ordered_select_mask],
153+
)
154+
155+
assert haz_fc_sel.centroids == haz_fc.centroids
156+
157+
def test_derived_select_single(self, haz_fc, lead_time, member):
158+
haz_fc_select = haz_fc.select(member=[3, 0])
159+
idx = np.array([0, 3])
160+
npt.assert_array_equal(haz_fc_select.event_id, haz_fc.event_id[idx])
161+
npt.assert_array_equal(haz_fc_select.member, member[idx])
162+
npt.assert_array_equal(haz_fc_select.lead_time, lead_time[idx])
163+
164+
haz_fc_select = haz_fc.select(lead_time=lead_time[np.array([3, 0])])
165+
npt.assert_array_equal(haz_fc_select.event_id, haz_fc.event_id[idx])
166+
npt.assert_array_equal(haz_fc_select.member, member[idx])
167+
npt.assert_array_equal(haz_fc_select.lead_time, lead_time[idx])
168+
169+
def test_derived_select_intersections(self, haz_fc, lead_time, member, haz_kwargs):
170+
haz_fc_select = haz_fc.select(event_id=[1, 4], member=[0, 1, 2])
171+
npt.assert_array_equal(haz_fc_select.event_id, haz_fc.event_id[np.array([0])])
172+
173+
haz_fc_select = haz_fc.select(
174+
event_id=[1, 2, 4], member=[0, 1, 2], lead_time=lead_time[1:3]
175+
)
176+
npt.assert_array_equal(haz_fc_select.event_id, haz_fc.event_id[np.array([1])])
177+
178+
# Test "outer"
179+
haz_fc2 = HazardForecast(
180+
lead_time=lead_time, member=np.zeros_like(member, dtype="int"), **haz_kwargs
181+
)
182+
haz_fc_select = haz_fc2.select(event_id=[1, 2, 4], member=[0])
183+
npt.assert_array_equal(haz_fc_select.event_id, [1, 2, 4])
184+
npt.assert_array_equal(haz_fc_select.member, [0, 0, 0])
185+
186+
def test_derived_select_null(self, haz_fc, haz_kwargs):
187+
haz_fc_select = haz_fc.select()
188+
assert_hazard_kwargs(haz_fc_select, **haz_kwargs)
189+
190+
with pytest.raises(IndexError):
191+
haz_fc.select(event_id=[-1])
192+
with pytest.raises(IndexError):
193+
haz_fc.select(member=[-1])
194+
with pytest.raises(IndexError):
195+
haz_fc.select(
196+
lead_time=[np.timedelta64("2", "Y").astype("timedelta64[ns]")]
197+
)
150198

151199

152200
def test_write_read_hazard_forecast(haz_fc, tmp_path):

0 commit comments

Comments
 (0)