Skip to content

Commit 574a3a7

Browse files
first draft impact forecast select
1 parent 9d6fef9 commit 574a3a7

File tree

2 files changed

+116
-48
lines changed

2 files changed

+116
-48
lines changed

climada/engine/impact_forecast.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -184,3 +184,39 @@ def _check_sizes(self):
184184
num_entries = len(self.event_id)
185185
size(exp_len=num_entries, var=self.member, var_name="Forecast.member")
186186
size(exp_len=num_entries, var=self.lead_time, var_name="Forecast.lead_time")
187+
188+
def select(
189+
self,
190+
event_ids=None,
191+
event_names=None,
192+
dates=None,
193+
coord_exp=None,
194+
reset_frequency=False,
195+
member=None,
196+
lead_time=None,
197+
):
198+
if member is not None or lead_time is not None:
199+
mask_member = (
200+
self.idx_member(member)
201+
if member is not None
202+
else np.full_like(self.member, True, dtype=bool)
203+
)
204+
mask_lead_time = (
205+
self.idx_lead_time(lead_time)
206+
if lead_time is not None
207+
else np.full_like(self.lead_time, True, dtype=bool)
208+
)
209+
mask_event_id = np.asarray(self.event_id)[(mask_member & mask_lead_time)]
210+
event_ids = (
211+
np.intersect1d(event_ids, mask_event_id)
212+
if event_ids is not None
213+
else mask_event_id
214+
)
215+
216+
return super().select(
217+
event_ids=event_ids,
218+
event_names=event_names,
219+
dates=dates,
220+
coord_exp=coord_exp,
221+
reset_frequency=reset_frequency,
222+
)

climada/engine/test/test_impact_forecast.py

Lines changed: 80 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -92,58 +92,90 @@ def test_impact_forecast_from_impact(
9292
self.assert_impact_kwargs(impact_forecast, **impact_kwargs)
9393

9494

95-
@pytest.mark.parametrize(
96-
"var, var_select",
97-
[("event_id", "event_ids"), ("event_name", "event_names"), ("date", "dates")],
98-
)
99-
def test_impact_forecast_select_events(
100-
impact_forecast, lead_time, member, impact_kwargs, var, var_select
101-
):
102-
"""Check if Impact.select works on the derived class"""
103-
select_mask = np.array([2, 1])
104-
ordered_select_mask = np.array([1, 2])
105-
if var == "date":
106-
# Date needs to be a valid delta
107-
select_mask = np.array([1, 2])
108-
ordered_select_mask = np.array([1, 2])
95+
class TestSelect:
10996

110-
var_value = np.array(impact_kwargs[var])[select_mask]
111-
# event_name is a list, convert to numpy array for indexing
112-
impact_fc = impact_forecast.select(**{var_select: var_value})
113-
# NOTE: Events keep their original order
114-
npt.assert_array_equal(
115-
impact_fc.event_id,
116-
impact_forecast.event_id[ordered_select_mask],
117-
)
118-
npt.assert_array_equal(
119-
impact_fc.event_name,
120-
np.array(impact_forecast.event_name)[ordered_select_mask],
121-
)
122-
npt.assert_array_equal(impact_fc.date, impact_forecast.date[ordered_select_mask])
123-
npt.assert_array_equal(
124-
impact_fc.frequency, impact_forecast.frequency[ordered_select_mask]
125-
)
126-
npt.assert_array_equal(impact_fc.member, member[ordered_select_mask])
127-
npt.assert_array_equal(impact_fc.lead_time, lead_time[ordered_select_mask])
128-
npt.assert_array_equal(
129-
impact_fc.imp_mat.todense(),
130-
impact_forecast.imp_mat.todense()[ordered_select_mask],
97+
@pytest.mark.parametrize(
98+
"var, var_select",
99+
[("event_id", "event_ids"), ("event_name", "event_names"), ("date", "dates")],
131100
)
101+
def test_base_class_select(
102+
self, impact_forecast, lead_time, member, impact_kwargs, var, var_select
103+
):
104+
"""Check if Impact.select works on the derived class"""
105+
select_mask = np.array([2, 1])
106+
ordered_select_mask = np.array([1, 2])
107+
if var == "date":
108+
# Date needs to be a valid delta
109+
select_mask = np.array([1, 2])
110+
ordered_select_mask = np.array([1, 2])
111+
112+
var_value = np.array(impact_kwargs[var])[select_mask]
113+
# event_name is a list, convert to numpy array for indexing
114+
impact_fc = impact_forecast.select(**{var_select: var_value})
115+
# NOTE: Events keep their original order
116+
npt.assert_array_equal(
117+
impact_fc.event_id,
118+
impact_forecast.event_id[ordered_select_mask],
119+
)
120+
npt.assert_array_equal(
121+
impact_fc.event_name,
122+
np.array(impact_forecast.event_name)[ordered_select_mask],
123+
)
124+
npt.assert_array_equal(
125+
impact_fc.date, impact_forecast.date[ordered_select_mask]
126+
)
127+
npt.assert_array_equal(
128+
impact_fc.frequency, impact_forecast.frequency[ordered_select_mask]
129+
)
130+
npt.assert_array_equal(impact_fc.member, member[ordered_select_mask])
131+
npt.assert_array_equal(impact_fc.lead_time, lead_time[ordered_select_mask])
132+
npt.assert_array_equal(
133+
impact_fc.imp_mat.todense(),
134+
impact_forecast.imp_mat.todense()[ordered_select_mask],
135+
)
132136

137+
def test_impact_forecast_select_exposure(
138+
self, impact_forecast, lead_time, member, impact_kwargs
139+
):
140+
"""Check if Impact.select works on the derived class"""
141+
exp_col = 0
142+
select_mask = np.array([exp_col])
143+
coord_exp = impact_kwargs["coord_exp"][select_mask]
144+
impact_fc = impact_forecast.select(coord_exp=coord_exp)
145+
npt.assert_array_equal(impact_fc.member, member)
146+
npt.assert_array_equal(impact_fc.lead_time, lead_time)
147+
npt.assert_array_equal(
148+
impact_fc.imp_mat.todense(), impact_forecast.imp_mat.todense()[:, exp_col]
149+
)
133150

134-
def test_impact_forecast_select_exposure(
135-
impact_forecast, lead_time, member, impact_kwargs
136-
):
137-
"""Check if Impact.select works on the derived class"""
138-
exp_col = 0
139-
select_mask = np.array([exp_col])
140-
coord_exp = impact_kwargs["coord_exp"][select_mask]
141-
impact_fc = impact_forecast.select(coord_exp=coord_exp)
142-
npt.assert_array_equal(impact_fc.member, member)
143-
npt.assert_array_equal(impact_fc.lead_time, lead_time)
144-
npt.assert_array_equal(
145-
impact_fc.imp_mat.todense(), impact_forecast.imp_mat.todense()[:, exp_col]
146-
)
151+
def test_derived_select(self, haz_fc, lead_time, member, haz_kwargs):
152+
haz_fc_select = haz_fc.select(member=[3, 0])
153+
idx = np.array([0, 3])
154+
npt.assert_array_equal(haz_fc_select.event_id, haz_fc.event_id[idx])
155+
npt.assert_array_equal(haz_fc_select.member, member[idx])
156+
npt.assert_array_equal(haz_fc_select.lead_time, lead_time[idx])
157+
158+
haz_fc_select = haz_fc.select(lead_time=lead_time[np.array([3, 0])])
159+
npt.assert_array_equal(haz_fc_select.event_id, haz_fc.event_id[idx])
160+
npt.assert_array_equal(haz_fc_select.member, member[idx])
161+
npt.assert_array_equal(haz_fc_select.lead_time, lead_time[idx])
162+
163+
# Test intersections
164+
haz_fc_select = haz_fc.select(event_id=[1, 4], member=[0, 1, 2])
165+
npt.assert_array_equal(haz_fc_select.event_id, haz_fc.event_id[np.array([0])])
166+
167+
haz_fc_select = haz_fc.select(
168+
event_id=[1, 2, 4], member=[0, 1, 2], lead_time=lead_time[1:3]
169+
)
170+
npt.assert_array_equal(haz_fc_select.event_id, haz_fc.event_id[np.array([1])])
171+
172+
# Test "outer"
173+
haz_fc2 = HazardForecast(
174+
lead_time=lead_time, member=np.zeros_like(member, dtype="int"), **haz_kwargs
175+
)
176+
haz_fc_select = haz_fc2.select(event_id=[1, 2, 4], member=[0])
177+
npt.assert_array_equal(haz_fc_select.event_id, [1, 2, 4])
178+
npt.assert_array_equal(haz_fc_select.member, [0, 0, 0])
147179

148180

149181
@pytest.mark.skip("Concat from base class does not work")

0 commit comments

Comments
 (0)