Skip to content

Commit 396239a

Browse files
author
Chahan Kropf
committed
Add extended select impact forecast test
1 parent a386d22 commit 396239a

File tree

1 file changed

+41
-6
lines changed

1 file changed

+41
-6
lines changed

climada/engine/test/test_impact_forecast.py

Lines changed: 41 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -86,14 +86,49 @@ def test_impact_forecast_from_impact(
8686

8787
def test_impact_forecast_select(impact_forecast, lead_time, member, impact_kwargs):
8888
"""Check if Impact.select works on the derived class"""
89-
event_ids = impact_kwargs["event_id"][np.array([2, 0])]
90-
impact_fc = impact_forecast.select(event_ids=event_ids)
91-
# NOTE: Events keep their original order
89+
select_mask = np.array([1, 2])
90+
ordered_select_mask = np.array([1, 2])
91+
vars_to_select = {
92+
"event_id": "event_ids",
93+
"event_name": "event_names",
94+
"date": "dates",
95+
}
96+
97+
for var, var_select in vars_to_select.items():
98+
var_value = np.array(impact_kwargs[var])[select_mask]
99+
# event_name is a list, convert to numpy array for indexing
100+
impact_fc = impact_forecast.select(**{var_select: var_value})
101+
# NOTE: Events keep their original order
102+
npt.assert_array_equal(
103+
impact_fc.event_id,
104+
impact_forecast.event_id[ordered_select_mask],
105+
)
106+
npt.assert_array_equal(
107+
impact_fc.event_name,
108+
np.array(impact_forecast.event_name)[ordered_select_mask],
109+
)
110+
npt.assert_array_equal(
111+
impact_fc.date, impact_forecast.date[ordered_select_mask]
112+
)
113+
npt.assert_array_equal(
114+
impact_fc.frequency, impact_forecast.frequency[ordered_select_mask]
115+
)
116+
npt.assert_array_equal(impact_fc.member, member[ordered_select_mask])
117+
npt.assert_array_equal(impact_fc.lead_time, lead_time[ordered_select_mask])
118+
npt.assert_array_equal(
119+
impact_fc.imp_mat.todense(),
120+
impact_forecast.imp_mat.todense()[ordered_select_mask],
121+
)
122+
123+
exp_col = 0
124+
select_mask = np.array([exp_col])
125+
coord_exp = impact_kwargs["coord_exp"][select_mask]
126+
impact_fc = impact_forecast.select(coord_exp=coord_exp)
127+
npt.assert_array_equal(impact_fc.member, member)
128+
npt.assert_array_equal(impact_fc.lead_time, lead_time)
92129
npt.assert_array_equal(
93-
impact_fc.event_id, impact_forecast.event_id[np.array([0, 2])]
130+
impact_fc.imp_mat.todense(), impact_forecast.imp_mat.todense()[:, exp_col]
94131
)
95-
npt.assert_array_equal(impact_fc.member, member[np.array([0, 2])])
96-
npt.assert_array_equal(impact_fc.lead_time, lead_time[np.array([0, 2])])
97132

98133

99134
@pytest.mark.skip("Concat from base class does not work")

0 commit comments

Comments
 (0)