Skip to content

Commit d3cc83a

Browse files
peanutfunEvelyn-M
andauthored
Check if Impact.select works on ImpactForecast (#1170)
* Add test for ImpactForecast.select * minor test modification --------- Co-authored-by: Evelyn Mühlhofer <[email protected]>
1 parent d04fecf commit d3cc83a

File tree

1 file changed

+49
-25
lines changed

1 file changed

+49
-25
lines changed

climada/engine/test/test_impact_forecast.py

Lines changed: 49 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -40,35 +40,59 @@ def impact(impact_kwargs):
4040
return Impact(**impact_kwargs)
4141

4242

43-
def assert_impact_kwargs(impact: Impact, **kwargs):
44-
for key, value in kwargs.items():
45-
attr = getattr(impact, key)
46-
if isinstance(value, (np.ndarray, list)):
47-
npt.assert_array_equal(attr, value)
48-
elif isinstance(value, csr_matrix):
49-
npt.assert_array_equal(attr.todense(), value.todense())
50-
else:
51-
assert attr == value
43+
@pytest.fixture
44+
def lead_time():
45+
return pd.timedelta_range(start="1 day", periods=6).to_numpy()
46+
47+
48+
@pytest.fixture
49+
def member():
50+
return np.arange(6)
51+
52+
53+
@pytest.fixture
54+
def impact_forecast(impact, lead_time, member):
55+
return ImpactForecast.from_impact(impact, lead_time=lead_time, member=member)
5256

5357

5458
class TestImpactForecastInit:
55-
lead_time = pd.timedelta_range(start="1 day", periods=6).to_numpy()
56-
member = np.arange(6)
5759

58-
def test_impact_forecast_init(self, impact_kwargs):
60+
def assert_impact_kwargs(self, impact: Impact, **kwargs):
61+
for key, value in kwargs.items():
62+
attr = getattr(impact, key)
63+
if isinstance(value, (np.ndarray, list)):
64+
npt.assert_array_equal(attr, value)
65+
elif isinstance(value, csr_matrix):
66+
npt.assert_array_equal(attr.todense(), value.todense())
67+
else:
68+
assert attr == value
69+
70+
def test_impact_forecast_init(self, impact_kwargs, lead_time, member):
5971
forecast1 = ImpactForecast(
60-
lead_time=self.lead_time,
61-
member=self.member,
72+
lead_time=lead_time,
73+
member=member,
6274
**impact_kwargs,
6375
)
64-
npt.assert_array_equal(forecast1.lead_time, self.lead_time)
65-
npt.assert_array_equal(forecast1.member, self.member)
66-
assert_impact_kwargs(forecast1, **impact_kwargs)
67-
68-
def test_impact_forecast_from_impact(self, impact, impact_kwargs):
69-
forecast = ImpactForecast.from_impact(
70-
impact, lead_time=self.lead_time, member=self.member
71-
)
72-
npt.assert_array_equal(forecast.lead_time, self.lead_time)
73-
npt.assert_array_equal(forecast.member, self.member)
74-
assert_impact_kwargs(forecast, **impact_kwargs)
76+
npt.assert_array_equal(forecast1.lead_time, lead_time)
77+
npt.assert_array_equal(forecast1.member, member)
78+
self.assert_impact_kwargs(forecast1, **impact_kwargs)
79+
80+
def test_impact_forecast_from_impact(
81+
self, impact_forecast, impact_kwargs, lead_time, member
82+
):
83+
npt.assert_array_equal(impact_forecast.lead_time, lead_time)
84+
npt.assert_array_equal(impact_forecast.member, member)
85+
self.assert_impact_kwargs(impact_forecast, **impact_kwargs)
86+
87+
88+
def test_impact_forecast_select(impact_forecast, lead_time, member, impact_kwargs):
89+
"""Check if Impact.select works on the derived class"""
90+
91+
event_ids = impact_kwargs["event_id"][np.array([2, 0])]
92+
impact_fc = impact_forecast.select(event_ids=event_ids)
93+
# NOTE: Events keep their original order
94+
npt.assert_array_equal(
95+
impact_fc.event_id, impact_forecast.event_id[np.array([0, 2])]
96+
)
97+
npt.assert_array_equal(impact_fc.member, member[np.array([0, 2])])
98+
npt.assert_array_equal(impact_fc.lead_time, lead_time[np.array([0, 2])])

0 commit comments

Comments
 (0)