Skip to content

Commit 64aec3d

Browse files
committed
Update docstring, reoganize tests
1 parent db1f23d commit 64aec3d

File tree

2 files changed

+96
-59
lines changed

2 files changed

+96
-59
lines changed

climada/hazard/forecast.py

Lines changed: 23 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
"""
2121

2222
import logging
23+
from typing import Self
2324

2425
import numpy as np
2526

@@ -107,16 +108,35 @@ def _check_sizes(self):
107108

108109
def select(
109110
self,
111+
member=None,
112+
lead_time=None,
110113
event_names=None,
111114
event_id=None,
112115
date=None,
113116
orig=None,
114117
reg_id=None,
115118
extent=None,
116119
reset_frequency=False,
117-
member=None,
118-
lead_time=None,
119-
):
120+
) -> Self:
121+
"""Select entries based on the parameters and return a new instance.
122+
123+
The selection will contain the intersection of all given parameters.
124+
125+
Parameters
126+
----------
127+
member : Sequence of ints
128+
Ensemble members to select
129+
lead_time : Sequence of numpy.timedelta64
130+
Lead times to select
131+
132+
Returns
133+
-------
134+
HazardForecast
135+
136+
See Also
137+
--------
138+
:py:meth:`~climada.hazard.base.Hazard.select`
139+
"""
120140
if member is not None or lead_time is not None:
121141
mask_member = (
122142
self.idx_member(member)

climada/hazard/test/test_forecast.py

Lines changed: 73 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -107,64 +107,81 @@ def test_hazard_forecast_concat(haz_fc, lead_time, member):
107107
npt.assert_array_equal(haz_fc_concat.member, np.concatenate([member, member]))
108108

109109

110-
@pytest.mark.parametrize(
111-
"var, var_select",
112-
[("event_id", "event_id"), ("event_name", "event_names"), ("date", "date")],
113-
)
114-
def test_hazard_forecast_select(haz_fc, lead_time, member, haz_kwargs, var, var_select):
115-
"""Check if Hazard.select works on the derived class"""
116-
117-
select_mask = np.array([3, 2])
118-
ordered_select_mask = np.array([3, 2])
119-
if var == "date":
120-
# Date needs to be a valid delta
121-
select_mask = np.array([2, 3])
122-
ordered_select_mask = np.array([2, 3])
123-
124-
var_value = np.array(haz_kwargs[var])[select_mask]
125-
# event_name is a list, convert to numpy array for indexing
126-
haz_fc_sel = haz_fc.select(**{var_select: var_value})
127-
# Note: order is preserved
128-
npt.assert_array_equal(
129-
haz_fc_sel.event_id,
130-
haz_fc.event_id[ordered_select_mask],
131-
)
132-
npt.assert_array_equal(
133-
haz_fc_sel.event_name,
134-
np.array(haz_fc.event_name)[ordered_select_mask],
135-
)
136-
npt.assert_array_equal(haz_fc_sel.date, haz_fc.date[ordered_select_mask])
137-
npt.assert_array_equal(haz_fc_sel.frequency, haz_fc.frequency[ordered_select_mask])
138-
npt.assert_array_equal(haz_fc_sel.member, member[ordered_select_mask])
139-
npt.assert_array_equal(haz_fc_sel.lead_time, lead_time[ordered_select_mask])
140-
npt.assert_array_equal(
141-
haz_fc_sel.intensity.todense(),
142-
haz_fc.intensity.todense()[ordered_select_mask],
143-
)
144-
npt.assert_array_equal(
145-
haz_fc_sel.fraction.todense(),
146-
haz_fc.fraction.todense()[ordered_select_mask],
147-
)
148-
149-
assert haz_fc_sel.centroids == haz_fc.centroids
150-
151-
haz_fc_select = haz_fc.select(member=[3, 0])
152-
npt.assert_array_equal(haz_fc_select.event_id, haz_fc.event_id[np.array([0, 3])])
153-
npt.assert_array_equal(haz_fc_select.member, member[np.array([0, 3])])
154-
npt.assert_array_equal(haz_fc_select.lead_time, lead_time[np.array([0, 3])])
155-
156-
haz_fc_select = haz_fc.select(lead_time=lead_time[np.array([3, 0])])
157-
npt.assert_array_equal(haz_fc_select.event_id, haz_fc.event_id[np.array([0, 3])])
158-
npt.assert_array_equal(haz_fc_select.member, member[np.array([0, 3])])
159-
npt.assert_array_equal(haz_fc_select.lead_time, lead_time[np.array([0, 3])])
160-
161-
haz_fc_select = haz_fc.select(event_id=[1, 4], member=[0, 1, 2])
162-
npt.assert_array_equal(haz_fc_select.event_id, haz_fc.event_id[np.array([0])])
110+
class TestSelect:
163111

164-
haz_fc_select = haz_fc.select(
165-
event_id=[1, 2, 4], member=[0, 1, 2], lead_time=lead_time[1:3]
112+
@pytest.mark.parametrize(
113+
"var, var_select",
114+
[("event_id", "event_id"), ("event_name", "event_names"), ("date", "date")],
166115
)
167-
npt.assert_array_equal(haz_fc_select.event_id, haz_fc.event_id[np.array([1])])
116+
def test_base_class_select(
117+
self, haz_fc, lead_time, member, haz_kwargs, var, var_select
118+
):
119+
"""Check if Hazard.select works on the derived class"""
120+
121+
select_mask = np.array([3, 2])
122+
ordered_select_mask = np.array([3, 2])
123+
if var == "date":
124+
# Date needs to be a valid delta
125+
select_mask = np.array([2, 3])
126+
ordered_select_mask = np.array([2, 3])
127+
128+
var_value = np.array(haz_kwargs[var])[select_mask]
129+
# event_name is a list, convert to numpy array for indexing
130+
haz_fc_sel = haz_fc.select(**{var_select: var_value})
131+
# Note: order is preserved
132+
npt.assert_array_equal(
133+
haz_fc_sel.event_id,
134+
haz_fc.event_id[ordered_select_mask],
135+
)
136+
npt.assert_array_equal(
137+
haz_fc_sel.event_name,
138+
np.array(haz_fc.event_name)[ordered_select_mask],
139+
)
140+
npt.assert_array_equal(haz_fc_sel.date, haz_fc.date[ordered_select_mask])
141+
npt.assert_array_equal(
142+
haz_fc_sel.frequency, haz_fc.frequency[ordered_select_mask]
143+
)
144+
npt.assert_array_equal(haz_fc_sel.member, member[ordered_select_mask])
145+
npt.assert_array_equal(haz_fc_sel.lead_time, lead_time[ordered_select_mask])
146+
npt.assert_array_equal(
147+
haz_fc_sel.intensity.todense(),
148+
haz_fc.intensity.todense()[ordered_select_mask],
149+
)
150+
npt.assert_array_equal(
151+
haz_fc_sel.fraction.todense(),
152+
haz_fc.fraction.todense()[ordered_select_mask],
153+
)
154+
155+
assert haz_fc_sel.centroids == haz_fc.centroids
156+
157+
def test_derived_select(self, haz_fc, lead_time, member, haz_kwargs):
158+
haz_fc_select = haz_fc.select(member=[3, 0])
159+
idx = np.array([0, 3])
160+
npt.assert_array_equal(haz_fc_select.event_id, haz_fc.event_id[idx])
161+
npt.assert_array_equal(haz_fc_select.member, member[idx])
162+
npt.assert_array_equal(haz_fc_select.lead_time, lead_time[idx])
163+
164+
haz_fc_select = haz_fc.select(lead_time=lead_time[np.array([3, 0])])
165+
npt.assert_array_equal(haz_fc_select.event_id, haz_fc.event_id[idx])
166+
npt.assert_array_equal(haz_fc_select.member, member[idx])
167+
npt.assert_array_equal(haz_fc_select.lead_time, lead_time[idx])
168+
169+
# Test intersections
170+
haz_fc_select = haz_fc.select(event_id=[1, 4], member=[0, 1, 2])
171+
npt.assert_array_equal(haz_fc_select.event_id, haz_fc.event_id[np.array([0])])
172+
173+
haz_fc_select = haz_fc.select(
174+
event_id=[1, 2, 4], member=[0, 1, 2], lead_time=lead_time[1:3]
175+
)
176+
npt.assert_array_equal(haz_fc_select.event_id, haz_fc.event_id[np.array([1])])
177+
178+
# Test "outer"
179+
haz_fc2 = HazardForecast(
180+
lead_time=lead_time, member=np.zeros_like(member, dtype="int"), **haz_kwargs
181+
)
182+
haz_fc_select = haz_fc2.select(event_id=[1, 2, 4], member=[0])
183+
npt.assert_array_equal(haz_fc_select.event_id, [1, 2, 4])
184+
npt.assert_array_equal(haz_fc_select.member, [0, 0, 0])
168185

169186

170187
def test_write_read_hazard_forecast(haz_fc, tmp_path):

0 commit comments

Comments
 (0)