Skip to content

Commit 347fcf1

Browse files
committed
Add test for ImpactForecast.select
1 parent 44bf09a commit 347fcf1

File tree

1 file changed

+47
-26
lines changed

1 file changed

+47
-26
lines changed

climada/engine/test/test_impact_forecast.py

Lines changed: 47 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -40,35 +40,56 @@ def impact(impact_kwargs):
4040
return Impact(**impact_kwargs)
4141

4242

43-
def assert_impact_kwargs(impact: Impact, **kwargs):
44-
for key, value in kwargs.items():
45-
attr = getattr(impact, key)
46-
if isinstance(value, (np.ndarray, list)):
47-
npt.assert_array_equal(attr, value)
48-
elif isinstance(value, csr_matrix):
49-
npt.assert_array_equal(attr.todense(), value.todense())
50-
else:
51-
assert attr == value
43+
@pytest.fixture
44+
def lead_time():
45+
return pd.date_range("2000-01-01", "2000-01-02", periods=6).to_numpy()
5246

5347

54-
class TestImpactForecastInit:
55-
lead_time = pd.date_range("2000-01-01", "2000-01-02", periods=6).to_numpy()
56-
member = np.arange(6)
48+
@pytest.fixture
49+
def member():
50+
return np.arange(6)
51+
52+
53+
@pytest.fixture
54+
def impact_forecast(impact, lead_time, member):
55+
return ImpactForecast.from_impact(impact, lead_time=lead_time, member=member)
56+
5757

58-
def test_impact_forecast_init(self, impact_kwargs):
58+
class TestImpactForecastInit:
59+
def assert_impact_kwargs(self, impact: Impact, **kwargs):
60+
for key, value in kwargs.items():
61+
attr = getattr(impact, key)
62+
if isinstance(value, (np.ndarray, list)):
63+
npt.assert_array_equal(attr, value)
64+
elif isinstance(value, csr_matrix):
65+
npt.assert_array_equal(attr.todense(), value.todense())
66+
else:
67+
assert attr == value
68+
69+
def test_impact_forecast_init(self, impact_kwargs, lead_time, member):
5970
forecast1 = ImpactForecast(
60-
lead_time=self.lead_time,
61-
member=self.member,
71+
lead_time=lead_time,
72+
member=member,
6273
**impact_kwargs,
6374
)
64-
npt.assert_array_equal(forecast1.lead_time, self.lead_time)
65-
npt.assert_array_equal(forecast1.member, self.member)
66-
assert_impact_kwargs(forecast1, **impact_kwargs)
67-
68-
def test_impact_forecast_from_impact(self, impact, impact_kwargs):
69-
forecast = ImpactForecast.from_impact(
70-
impact, lead_time=self.lead_time, member=self.member
71-
)
72-
npt.assert_array_equal(forecast.lead_time, self.lead_time)
73-
npt.assert_array_equal(forecast.member, self.member)
74-
assert_impact_kwargs(forecast, **impact_kwargs)
75+
npt.assert_array_equal(forecast1.lead_time, lead_time)
76+
npt.assert_array_equal(forecast1.member, member)
77+
self.assert_impact_kwargs(forecast1, **impact_kwargs)
78+
79+
def test_impact_forecast_from_impact(
80+
self, impact_forecast, impact_kwargs, lead_time, member
81+
):
82+
npt.assert_array_equal(impact_forecast.lead_time, lead_time)
83+
npt.assert_array_equal(impact_forecast.member, member)
84+
self.assert_impact_kwargs(impact_forecast, **impact_kwargs)
85+
86+
87+
def test_impact_forecast_select(impact_forecast, lead_time, member):
88+
"""Check if Impact.select works on the derived class"""
89+
impact_fc = impact_forecast.select(event_ids=[12, 10])
90+
# NOTE: Events keep their original order
91+
npt.assert_array_equal(
92+
impact_fc.event_id, impact_forecast.event_id[np.array([0, 2])]
93+
)
94+
npt.assert_array_equal(impact_fc.member, member[np.array([0, 2])])
95+
npt.assert_array_equal(impact_fc.lead_time, lead_time[np.array([0, 2])])

0 commit comments

Comments
 (0)