Skip to content

Commit 781cea6

Browse files
Merge branch 'idx_member_leadtime' into add_sel_to_haz_and_imp_forecast
2 parents c1eefdd + 12bf2b6 commit 781cea6

File tree

2 files changed

+69
-0
lines changed

2 files changed

+69
-0
lines changed

climada/util/forecast.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,3 +56,35 @@ def __init__(
5656
)
5757
self.member = np.asarray(member) if member is not None else np.array([])
5858
super().__init__(**kwargs)
59+
60+
def idx_member(self, member: np.ndarray) -> np.ndarray:
61+
"""Return boolean array where self.member == member using numpy.isin()
62+
63+
Parameters
64+
----------
65+
member : np.ndarray
66+
Array of ensemble members (ints) for which to return an indexer
67+
68+
Returns
69+
-------
70+
np.ndarray
71+
Boolean array where self.member is in member.
72+
"""
73+
74+
return np.isin(self.member, member)
75+
76+
def idx_lead_time(self, lead_time: np.ndarray) -> np.ndarray:
77+
"""Return boolean array where self.lead_time == lead_time using numpy.isin()
78+
79+
Parameters
80+
----------
81+
lead_time : np.ndarray
82+
Array of lead times (numpy.timedelta64) for which to return an indexer
83+
84+
Returns
85+
-------
86+
np.ndarray
87+
Boolean array where self.lead_time is in lead_time.
88+
"""
89+
90+
return np.isin(self.lead_time, lead_time)

climada/util/test/test_forecast.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,3 +50,40 @@ def test_forecast_init():
5050
forecast = Forecast(lead_time=lead_times_seconds, member=[1, 2, 3])
5151
npt.assert_array_equal(forecast.lead_time, lead_times_seconds, strict=True)
5252
assert forecast.lead_time.dtype == np.dtype("timedelta64[ns]")
53+
54+
55+
def test_idx_member():
56+
"""Test idx_member method of Forecast class."""
57+
forecast = Forecast(member=np.array([1, 2, 3, 4]))
58+
59+
idx = forecast.idx_member(1)
60+
npt.assert_array_equal(idx, np.array([True, False, False, False]), strict=True)
61+
62+
idx = forecast.idx_member(np.array([2, 4]))
63+
npt.assert_array_equal(idx, np.array([False, True, False, True]), strict=True)
64+
65+
idx = forecast.idx_member([2, 4])
66+
npt.assert_array_equal(idx, np.array([False, True, False, True]), strict=True)
67+
68+
idx = forecast.idx_member(None)
69+
npt.assert_array_equal(idx, np.array([False, False, False, False]), strict=True)
70+
71+
72+
def test_idx_lead_time():
73+
"""Test idx_lead_time method of Forecast class."""
74+
forecast = Forecast(
75+
lead_time=pd.timedelta_range(start="1 day", periods=4).to_numpy()
76+
)
77+
78+
idx = forecast.idx_lead_time(
79+
pd.timedelta_range(start="1 day", periods=4).to_numpy()[::2]
80+
)
81+
npt.assert_array_equal(idx, np.array([True, False, True, False]), strict=True)
82+
83+
idx = forecast.idx_lead_time(
84+
pd.timedelta_range(start="1 day", periods=4).to_numpy()[0]
85+
)
86+
npt.assert_array_equal(idx, np.array([True, False, False, False]), strict=True)
87+
88+
idx = forecast.idx_lead_time(None)
89+
npt.assert_array_equal(idx, np.array([False, False, False, False]), strict=True)

0 commit comments

Comments
 (0)