Skip to content

Commit dd78fac

Browse files
author
Chahan Kropf
committed
Rewrite select test with test parameters
1 parent 9249e01 commit dd78fac

File tree

2 files changed

+76
-68
lines changed

2 files changed

+76
-68
lines changed

climada/engine/test/test_impact_forecast.py

Lines changed: 39 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -84,42 +84,49 @@ def test_impact_forecast_from_impact(
8484
self.assert_impact_kwargs(impact_forecast, **impact_kwargs)
8585

8686

87-
def test_impact_forecast_select(impact_forecast, lead_time, member, impact_kwargs):
87+
@pytest.mark.parametrize(
88+
"var, var_select",
89+
[("event_id", "event_ids"), ("event_name", "event_names"), ("date", "dates")],
90+
)
91+
def test_impact_forecast_select_events(
92+
impact_forecast, lead_time, member, impact_kwargs, var, var_select
93+
):
8894
"""Check if Impact.select works on the derived class"""
8995
select_mask = np.array([2, 1])
9096
ordered_select_mask = np.array([1, 2])
91-
vars_to_select = {
92-
"event_id": "event_ids",
93-
"event_name": "event_names",
94-
"date": "dates",
95-
}
96-
97-
for var, var_select in vars_to_select.items():
98-
var_value = np.array(impact_kwargs[var])[select_mask]
99-
# event_name is a list, convert to numpy array for indexing
100-
impact_fc = impact_forecast.select(**{var_select: var_value})
101-
# NOTE: Events keep their original order
102-
npt.assert_array_equal(
103-
impact_fc.event_id,
104-
impact_forecast.event_id[ordered_select_mask],
105-
)
106-
npt.assert_array_equal(
107-
impact_fc.event_name,
108-
np.array(impact_forecast.event_name)[ordered_select_mask],
109-
)
110-
npt.assert_array_equal(
111-
impact_fc.date, impact_forecast.date[ordered_select_mask]
112-
)
113-
npt.assert_array_equal(
114-
impact_fc.frequency, impact_forecast.frequency[ordered_select_mask]
115-
)
116-
npt.assert_array_equal(impact_fc.member, member[ordered_select_mask])
117-
npt.assert_array_equal(impact_fc.lead_time, lead_time[ordered_select_mask])
118-
npt.assert_array_equal(
119-
impact_fc.imp_mat.todense(),
120-
impact_forecast.imp_mat.todense()[ordered_select_mask],
121-
)
97+
if var == "date":
98+
# Date needs to be a valid delta
99+
select_mask = np.array([1, 2])
100+
ordered_select_mask = np.array([1, 2])
101+
102+
var_value = np.array(impact_kwargs[var])[select_mask]
103+
# event_name is a list, convert to numpy array for indexing
104+
impact_fc = impact_forecast.select(**{var_select: var_value})
105+
# NOTE: Events keep their original order
106+
npt.assert_array_equal(
107+
impact_fc.event_id,
108+
impact_forecast.event_id[ordered_select_mask],
109+
)
110+
npt.assert_array_equal(
111+
impact_fc.event_name,
112+
np.array(impact_forecast.event_name)[ordered_select_mask],
113+
)
114+
npt.assert_array_equal(impact_fc.date, impact_forecast.date[ordered_select_mask])
115+
npt.assert_array_equal(
116+
impact_fc.frequency, impact_forecast.frequency[ordered_select_mask]
117+
)
118+
npt.assert_array_equal(impact_fc.member, member[ordered_select_mask])
119+
npt.assert_array_equal(impact_fc.lead_time, lead_time[ordered_select_mask])
120+
npt.assert_array_equal(
121+
impact_fc.imp_mat.todense(),
122+
impact_forecast.imp_mat.todense()[ordered_select_mask],
123+
)
124+
122125

126+
def test_impact_forecast_select_exposure(
127+
impact_forecast, lead_time, member, impact_kwargs
128+
):
129+
"""Check if Impact.select works on the derived class"""
123130
exp_col = 0
124131
select_mask = np.array([exp_col])
125132
coord_exp = impact_kwargs["coord_exp"][select_mask]

climada/hazard/test/test_forecast.py

Lines changed: 37 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -100,42 +100,43 @@ def test_hazard_forecast_concat(haz_fc, lead_time, member):
100100
npt.assert_array_equal(haz_fc_concat.member, np.concatenate([member, member]))
101101

102102

103-
def test_hazard_forecast_select(haz_fc, lead_time, member, haz_kwargs):
103+
@pytest.mark.parametrize(
104+
"var, var_select",
105+
[("event_id", "event_id"), ("event_name", "event_names"), ("date", "date")],
106+
)
107+
def test_hazard_forecast_select(haz_fc, lead_time, member, haz_kwargs, var, var_select):
104108
"""Check if Hazard.select works on the derived class"""
105109

106110
select_mask = np.array([3, 2])
107-
ordered_select_mask = np.array([2, 3])
108-
vars_to_select = {
109-
"event_id": "event_id",
110-
"event_name": "event_names",
111-
"date": "date",
112-
}
113-
114-
for var, var_select in vars_to_select.items():
115-
var_value = np.array(haz_kwargs[var])[select_mask]
116-
# event_name is a list, convert to numpy array for indexing
117-
haz_fc_sel = haz_fc.select(**{var_select: var_value})
118-
npt.assert_array_equal(
119-
haz_fc_sel.event_id,
120-
haz_fc.event_id[ordered_select_mask],
121-
)
122-
npt.assert_array_equal(
123-
haz_fc_sel.event_name,
124-
np.array(haz_fc.event_name)[ordered_select_mask],
125-
)
126-
npt.assert_array_equal(haz_fc_sel.date, haz_fc.date[ordered_select_mask])
127-
npt.assert_array_equal(
128-
haz_fc_sel.frequency, haz_fc.frequency[ordered_select_mask]
129-
)
130-
npt.assert_array_equal(haz_fc_sel.member, member[ordered_select_mask])
131-
npt.assert_array_equal(haz_fc_sel.lead_time, lead_time[ordered_select_mask])
132-
npt.assert_array_equal(
133-
haz_fc_sel.intensity.todense(),
134-
haz_fc.intensity.todense()[ordered_select_mask],
135-
)
136-
npt.assert_array_equal(
137-
haz_fc_sel.fraction.todense(),
138-
haz_fc.fraction.todense()[ordered_select_mask],
139-
)
140-
141-
assert haz_fc_sel.centroids == haz_fc.centroids
111+
ordered_select_mask = np.array([3, 2])
112+
if var == "date":
113+
# Date needs to be a valid delta
114+
select_mask = np.array([2, 3])
115+
ordered_select_mask = np.array([2, 3])
116+
117+
var_value = np.array(haz_kwargs[var])[select_mask]
118+
# event_name is a list, convert to numpy array for indexing
119+
haz_fc_sel = haz_fc.select(**{var_select: var_value})
120+
# Note: order is preserved
121+
npt.assert_array_equal(
122+
haz_fc_sel.event_id,
123+
haz_fc.event_id[ordered_select_mask],
124+
)
125+
npt.assert_array_equal(
126+
haz_fc_sel.event_name,
127+
np.array(haz_fc.event_name)[ordered_select_mask],
128+
)
129+
npt.assert_array_equal(haz_fc_sel.date, haz_fc.date[ordered_select_mask])
130+
npt.assert_array_equal(haz_fc_sel.frequency, haz_fc.frequency[ordered_select_mask])
131+
npt.assert_array_equal(haz_fc_sel.member, member[ordered_select_mask])
132+
npt.assert_array_equal(haz_fc_sel.lead_time, lead_time[ordered_select_mask])
133+
npt.assert_array_equal(
134+
haz_fc_sel.intensity.todense(),
135+
haz_fc.intensity.todense()[ordered_select_mask],
136+
)
137+
npt.assert_array_equal(
138+
haz_fc_sel.fraction.todense(),
139+
haz_fc.fraction.todense()[ordered_select_mask],
140+
)
141+
142+
assert haz_fc_sel.centroids == haz_fc.centroids

0 commit comments

Comments
 (0)