Skip to content

Commit 85b92ef

Browse files
committed
Make Impact.concat support ImpactForecast
1 parent 302be63 commit 85b92ef

File tree

2 files changed

+18
-5
lines changed

2 files changed

+18
-5
lines changed

climada/engine/impact.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2208,9 +2208,12 @@ def stack_attribute(attr_name: str) -> np.ndarray:
22082208
imp_mat = sparse.vstack(imp_mats)
22092209

22102210
# Concatenate other attributes
2211-
kwargs = {
2212-
attr: stack_attribute(attr) for attr in ("date", "frequency", "at_event")
2213-
}
2211+
concat_attrs = {
2212+
name.lstrip("_") # Private attributes with getter/setter
2213+
for name, value in first_imp.__dict__.items()
2214+
if isinstance(value, np.ndarray)
2215+
}.difference(("event_id", "coord_exp", "eai_exp", "aai_agg"))
2216+
kwargs = {attr: stack_attribute(attr) for attr in concat_attrs}
22142217

22152218
# Get remaining attributes from first impact object in list
22162219
return cls(

climada/engine/test/test_impact_forecast.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -202,13 +202,23 @@ def test_no_select(self, impact_forecast, impact_kwargs):
202202
assert imp_fc_select.imp_mat.shape == (0, num_centroids)
203203

204204

205-
@pytest.mark.skip("Concat from base class does not work")
206-
def test_impact_forecast_concat(impact_forecast, member):
205+
def test_impact_forecast_concat(impact_forecast, member, lead_time):
207206
"""Check if Impact.concat works on the derived class"""
208207
impact_fc = ImpactForecast.concat(
209208
[impact_forecast, impact_forecast], reset_event_ids=True
210209
)
211210
npt.assert_array_equal(impact_fc.member, np.concatenate([member, member]))
211+
npt.assert_array_equal(impact_fc.lead_time, np.concatenate([lead_time, lead_time]))
212+
npt.assert_array_equal(
213+
impact_fc.event_id, np.arange(impact_fc.imp_mat.shape[0]) + 1
214+
)
215+
npt.assert_array_equal(impact_fc.event_name, impact_forecast.event_name * 2)
216+
npt.assert_array_equal(
217+
impact_fc.imp_mat.toarray(),
218+
np.vstack(
219+
(impact_forecast.imp_mat.toarray(), impact_forecast.imp_mat.toarray())
220+
),
221+
)
212222

213223

214224
def test_impact_forecast_blocked_methods(impact_forecast):

0 commit comments

Comments
 (0)