Skip to content

Commit c1eefdd

Browse files
add HazardForecast.select and extent test
1 parent 48b6d40 commit c1eefdd

File tree

2 files changed

+58
-0
lines changed

2 files changed

+58
-0
lines changed

climada/hazard/forecast.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -104,3 +104,43 @@ def _check_sizes(self):
104104
num_entries = len(self.event_id)
105105
size(exp_len=num_entries, var=self.member, var_name="Forecast.member")
106106
size(exp_len=num_entries, var=self.lead_time, var_name="Forecast.lead_time")
107+
108+
def select(
109+
self,
110+
event_names=None,
111+
event_id=None,
112+
date=None,
113+
orig=None,
114+
reg_id=None,
115+
extent=None,
116+
reset_frequency=False,
117+
member=None,
118+
lead_time=None,
119+
):
120+
if member is not None or lead_time is not None:
121+
mask_member = (
122+
self.idx_member(member)
123+
if member is not None
124+
else np.full_like(self.member, True, dtype=bool)
125+
)
126+
mask_lead_time = (
127+
self.idx_lead_time(lead_time)
128+
if lead_time is not None
129+
else np.full_like(self.lead_time, True, dtype=bool)
130+
)
131+
mask_event_id = np.asarray(self.event_id)[(mask_member & mask_lead_time)]
132+
event_id = (
133+
np.intersect1d(event_id, mask_event_id)
134+
if event_id is not None
135+
else mask_event_id
136+
)
137+
138+
return super().select(
139+
event_names=event_names,
140+
event_id=event_id,
141+
date=date,
142+
orig=orig,
143+
reg_id=reg_id,
144+
extent=extent,
145+
reset_frequency=reset_frequency,
146+
)

climada/hazard/test/test_forecast.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,24 @@ def test_hazard_forecast_select(haz_fc, lead_time, member):
115115
npt.assert_array_equal(haz_fc_select.member, member[np.array([3, 0])])
116116
npt.assert_array_equal(haz_fc_select.lead_time, lead_time[np.array([3, 0])])
117117

118+
haz_fc_select = haz_fc.select(member=[3, 0])
119+
npt.assert_array_equal(haz_fc_select.event_id, haz_fc.event_id[np.array([0, 3])])
120+
npt.assert_array_equal(haz_fc_select.member, member[np.array([0, 3])])
121+
npt.assert_array_equal(haz_fc_select.lead_time, lead_time[np.array([0, 3])])
122+
123+
haz_fc_select = haz_fc.select(lead_time=lead_time[np.array([3, 0])])
124+
npt.assert_array_equal(haz_fc_select.event_id, haz_fc.event_id[np.array([0, 3])])
125+
npt.assert_array_equal(haz_fc_select.member, member[np.array([0, 3])])
126+
npt.assert_array_equal(haz_fc_select.lead_time, lead_time[np.array([0, 3])])
127+
128+
haz_fc_select = haz_fc.select(event_id=[1, 4], member=[0, 1, 2])
129+
npt.assert_array_equal(haz_fc_select.event_id, haz_fc.event_id[np.array([0])])
130+
131+
haz_fc_select = haz_fc.select(
132+
event_id=[1, 2, 4], member=[0, 1, 2], lead_time=lead_time[1:3]
133+
)
134+
npt.assert_array_equal(haz_fc_select.event_id, haz_fc.event_id[np.array([1])])
135+
118136

119137
def test_write_read_hazard_forecast(haz_fc, tmp_path):
120138

0 commit comments

Comments
 (0)