Skip to content

Commit b17bcc9

Browse files
Add select to ImpactForecast (#1188)
* Add ImpactForecast.select --------- Co-authored-by: Lukas Riedel <[email protected]>
1 parent 9d6fef9 commit b17bcc9

File tree

2 files changed

+160
-48
lines changed

2 files changed

+160
-48
lines changed

climada/engine/impact_forecast.py

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -184,3 +184,59 @@ def _check_sizes(self):
184184
num_entries = len(self.event_id)
185185
size(exp_len=num_entries, var=self.member, var_name="Forecast.member")
186186
size(exp_len=num_entries, var=self.lead_time, var_name="Forecast.lead_time")
187+
188+
def select(
189+
self,
190+
event_ids=None,
191+
event_names=None,
192+
dates=None,
193+
coord_exp=None,
194+
reset_frequency=False,
195+
member=None,
196+
lead_time=None,
197+
):
198+
"""Select entries based on the parameters and return a new instance.
199+
The selection will contain the intersection of all given parameters.
200+
201+
Parameters
202+
----------
203+
member : Sequence of ints
204+
Ensemble members to select
205+
lead_time : Sequence of numpy.timedelta64
206+
Lead times to select
207+
208+
Returns
209+
-------
210+
ImpactForecast
211+
212+
See Also
213+
--------
214+
:py:meth:`~climada.engine.impact.Impact.select`
215+
"""
216+
if member is not None or lead_time is not None:
217+
mask_member = (
218+
self.idx_member(member)
219+
if member is not None
220+
else np.full_like(self.member, True, dtype=bool)
221+
)
222+
mask_lead_time = (
223+
self.idx_lead_time(lead_time)
224+
if lead_time is not None
225+
else np.full_like(self.lead_time, True, dtype=bool)
226+
)
227+
event_id_from_forecast_mask = np.asarray(self.event_id)[
228+
(mask_member & mask_lead_time)
229+
]
230+
event_ids = (
231+
np.intersect1d(event_ids, event_id_from_forecast_mask)
232+
if event_ids is not None
233+
else event_id_from_forecast_mask
234+
)
235+
236+
return super().select(
237+
event_ids=event_ids,
238+
event_names=event_names,
239+
dates=dates,
240+
coord_exp=coord_exp,
241+
reset_frequency=reset_frequency,
242+
)

climada/engine/test/test_impact_forecast.py

Lines changed: 104 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -92,58 +92,114 @@ def test_impact_forecast_from_impact(
9292
self.assert_impact_kwargs(impact_forecast, **impact_kwargs)
9393

9494

95-
@pytest.mark.parametrize(
96-
"var, var_select",
97-
[("event_id", "event_ids"), ("event_name", "event_names"), ("date", "dates")],
98-
)
99-
def test_impact_forecast_select_events(
100-
impact_forecast, lead_time, member, impact_kwargs, var, var_select
101-
):
102-
"""Check if Impact.select works on the derived class"""
103-
select_mask = np.array([2, 1])
104-
ordered_select_mask = np.array([1, 2])
105-
if var == "date":
106-
# Date needs to be a valid delta
107-
select_mask = np.array([1, 2])
108-
ordered_select_mask = np.array([1, 2])
95+
class TestSelect:
10996

110-
var_value = np.array(impact_kwargs[var])[select_mask]
111-
# event_name is a list, convert to numpy array for indexing
112-
impact_fc = impact_forecast.select(**{var_select: var_value})
113-
# NOTE: Events keep their original order
114-
npt.assert_array_equal(
115-
impact_fc.event_id,
116-
impact_forecast.event_id[ordered_select_mask],
117-
)
118-
npt.assert_array_equal(
119-
impact_fc.event_name,
120-
np.array(impact_forecast.event_name)[ordered_select_mask],
121-
)
122-
npt.assert_array_equal(impact_fc.date, impact_forecast.date[ordered_select_mask])
123-
npt.assert_array_equal(
124-
impact_fc.frequency, impact_forecast.frequency[ordered_select_mask]
125-
)
126-
npt.assert_array_equal(impact_fc.member, member[ordered_select_mask])
127-
npt.assert_array_equal(impact_fc.lead_time, lead_time[ordered_select_mask])
128-
npt.assert_array_equal(
129-
impact_fc.imp_mat.todense(),
130-
impact_forecast.imp_mat.todense()[ordered_select_mask],
97+
@pytest.mark.parametrize(
98+
"var, var_select",
99+
[("event_id", "event_ids"), ("event_name", "event_names"), ("date", "dates")],
131100
)
101+
def test_base_class_select(
102+
self, impact_forecast, lead_time, member, impact_kwargs, var, var_select
103+
):
104+
"""Check if Impact.select works on the derived class"""
105+
select_mask = np.array([2, 1])
106+
ordered_select_mask = np.array([1, 2])
107+
if var == "date":
108+
# Date needs to be a valid delta
109+
select_mask = np.array([1, 2])
110+
ordered_select_mask = np.array([1, 2])
111+
112+
var_value = np.array(impact_kwargs[var])[select_mask]
113+
# event_name is a list, convert to numpy array for indexing
114+
impact_fc = impact_forecast.select(**{var_select: var_value})
115+
# NOTE: Events keep their original order
116+
npt.assert_array_equal(
117+
impact_fc.event_id,
118+
impact_forecast.event_id[ordered_select_mask],
119+
)
120+
npt.assert_array_equal(
121+
impact_fc.event_name,
122+
np.array(impact_forecast.event_name)[ordered_select_mask],
123+
)
124+
npt.assert_array_equal(
125+
impact_fc.date, impact_forecast.date[ordered_select_mask]
126+
)
127+
npt.assert_array_equal(
128+
impact_fc.frequency, impact_forecast.frequency[ordered_select_mask]
129+
)
130+
npt.assert_array_equal(impact_fc.member, member[ordered_select_mask])
131+
npt.assert_array_equal(impact_fc.lead_time, lead_time[ordered_select_mask])
132+
npt.assert_array_equal(
133+
impact_fc.imp_mat.todense(),
134+
impact_forecast.imp_mat.todense()[ordered_select_mask],
135+
)
132136

137+
def test_impact_forecast_select_exposure(
138+
self, impact_forecast, lead_time, member, impact_kwargs
139+
):
140+
"""Check if Impact.select works on the derived class"""
141+
exp_col = 0
142+
select_mask = np.array([exp_col])
143+
coord_exp = impact_kwargs["coord_exp"][select_mask]
144+
impact_fc = impact_forecast.select(coord_exp=coord_exp)
145+
npt.assert_array_equal(impact_fc.member, member)
146+
npt.assert_array_equal(impact_fc.lead_time, lead_time)
147+
npt.assert_array_equal(
148+
impact_fc.imp_mat.todense(), impact_forecast.imp_mat.todense()[:, exp_col]
149+
)
133150

134-
def test_impact_forecast_select_exposure(
135-
impact_forecast, lead_time, member, impact_kwargs
136-
):
137-
"""Check if Impact.select works on the derived class"""
138-
exp_col = 0
139-
select_mask = np.array([exp_col])
140-
coord_exp = impact_kwargs["coord_exp"][select_mask]
141-
impact_fc = impact_forecast.select(coord_exp=coord_exp)
142-
npt.assert_array_equal(impact_fc.member, member)
143-
npt.assert_array_equal(impact_fc.lead_time, lead_time)
144-
npt.assert_array_equal(
145-
impact_fc.imp_mat.todense(), impact_forecast.imp_mat.todense()[:, exp_col]
146-
)
151+
def test_derived_select_single(self, impact_forecast, lead_time, member):
152+
imp_fc_select = impact_forecast.select(member=[2, 0])
153+
idx = np.array([0, 2])
154+
npt.assert_array_equal(imp_fc_select.event_id, impact_forecast.event_id[idx])
155+
npt.assert_array_equal(imp_fc_select.member, member[idx])
156+
npt.assert_array_equal(imp_fc_select.lead_time, lead_time[idx])
157+
158+
imp_fc_select = impact_forecast.select(lead_time=lead_time[np.array([2, 0])])
159+
npt.assert_array_equal(imp_fc_select.event_id, impact_forecast.event_id[idx])
160+
npt.assert_array_equal(imp_fc_select.member, member[idx])
161+
npt.assert_array_equal(imp_fc_select.lead_time, lead_time[idx])
162+
163+
def test_derived_select_intersections(
164+
self, impact_forecast, lead_time, member, impact_kwargs
165+
):
166+
imp_fc_select = impact_forecast.select(event_ids=[10, 14], member=[0, 1, 2])
167+
npt.assert_array_equal(
168+
imp_fc_select.event_id, impact_forecast.event_id[np.array([0])]
169+
)
170+
171+
imp_fc_select = impact_forecast.select(
172+
event_ids=[10, 11, 13], member=[0, 1, 2], lead_time=lead_time[1:3]
173+
)
174+
npt.assert_array_equal(
175+
imp_fc_select.event_id, impact_forecast.event_id[np.array([1])]
176+
)
177+
178+
# Test "outer"
179+
impact_forecast2 = ImpactForecast(
180+
lead_time=lead_time,
181+
member=np.zeros_like(member, dtype="int"),
182+
**impact_kwargs,
183+
)
184+
imp_fc_select = impact_forecast2.select(event_ids=[10, 11, 13], member=[0])
185+
npt.assert_array_equal(imp_fc_select.event_id, [10, 11, 13])
186+
npt.assert_array_equal(imp_fc_select.member, [0, 0, 0])
187+
188+
def test_no_select(self, impact_forecast, impact_kwargs):
189+
imp_fc_select = impact_forecast.select()
190+
npt.assert_array_equal(
191+
imp_fc_select.imp_mat.todense(), impact_forecast.imp_mat.todense()
192+
)
193+
194+
num_centroids = len(impact_kwargs["coord_exp"])
195+
imp_fc_select = impact_forecast.select(event_names=["aaaaa", "foo"])
196+
assert imp_fc_select.imp_mat.shape == (0, num_centroids)
197+
imp_fc_select = impact_forecast.select(event_ids=[-1, 1002])
198+
assert imp_fc_select.imp_mat.shape == (0, num_centroids)
199+
imp_fc_select = impact_forecast.select(member=[-1])
200+
assert imp_fc_select.imp_mat.shape == (0, num_centroids)
201+
imp_fc_select = impact_forecast.select(np.timedelta64("3", "Y"))
202+
assert imp_fc_select.imp_mat.shape == (0, num_centroids)
147203

148204

149205
@pytest.mark.skip("Concat from base class does not work")

0 commit comments

Comments
 (0)