Skip to content

Commit 64149df

Browse files
committed
Simplify code and tests
1 parent 23660de commit 64149df

File tree

4 files changed

+56
-124
lines changed

4 files changed

+56
-124
lines changed

climada/engine/impact_forecast.py

Lines changed: 5 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -228,16 +228,9 @@ def min(self):
228228
ImpactForecast
229229
An ImpactForecast object with the min impact matrix and at_event.
230230
"""
231-
red_imp_mat = sparse.csr_matrix(self.imp_mat.min(axis=0))
231+
red_imp_mat = self.imp_mat.min(axis=0).tocsr()
232232
red_at_event = np.array([red_imp_mat.sum()])
233-
reduced_attrs = self._reduce_attrs("min")
234233
return ImpactForecast(
235-
lead_time=reduced_attrs["lead_time"],
236-
member=reduced_attrs["member"],
237-
event_id=reduced_attrs["event_id"],
238-
event_name=reduced_attrs["event_name"],
239-
date=reduced_attrs["date"],
240-
frequency=reduced_attrs["frequency"],
241234
frequency_unit=self.frequency_unit,
242235
coord_exp=self.coord_exp,
243236
crs=self.crs,
@@ -248,6 +241,7 @@ def min(self):
248241
unit=self.unit,
249242
imp_mat=red_imp_mat,
250243
haz_type=self.haz_type,
244+
**self._reduce_attrs("min"),
251245
)
252246

253247
def max(self):
@@ -264,16 +258,9 @@ def max(self):
264258
ImpactForecast
265259
An ImpactForecast object with the max impact matrix and at_event.
266260
"""
267-
red_imp_mat = sparse.csr_matrix(self.imp_mat.max(axis=0))
261+
red_imp_mat = self.imp_mat.max(axis=0).tocsr()
268262
red_at_event = np.array([red_imp_mat.sum()])
269-
reduced_attrs = self._reduce_attrs("max")
270263
return ImpactForecast(
271-
lead_time=reduced_attrs["lead_time"],
272-
member=reduced_attrs["member"],
273-
event_id=reduced_attrs["event_id"],
274-
event_name=reduced_attrs["event_name"],
275-
date=reduced_attrs["date"],
276-
frequency=reduced_attrs["frequency"],
277264
frequency_unit=self.frequency_unit,
278265
coord_exp=self.coord_exp,
279266
crs=self.crs,
@@ -284,6 +271,7 @@ def max(self):
284271
unit=self.unit,
285272
imp_mat=red_imp_mat,
286273
haz_type=self.haz_type,
274+
**self._reduce_attrs("max"),
287275
)
288276

289277
def mean(self):
@@ -301,14 +289,7 @@ def mean(self):
301289
"""
302290
red_imp_mat = sparse.csr_matrix(self.imp_mat.mean(axis=0))
303291
red_at_event = np.array([red_imp_mat.sum()])
304-
reduced_attrs = self._reduce_attrs("mean")
305292
return ImpactForecast(
306-
lead_time=reduced_attrs["lead_time"],
307-
member=reduced_attrs["member"],
308-
event_id=reduced_attrs["event_id"],
309-
event_name=reduced_attrs["event_name"],
310-
date=reduced_attrs["date"],
311-
frequency=reduced_attrs["frequency"],
312293
frequency_unit=self.frequency_unit,
313294
coord_exp=self.coord_exp,
314295
crs=self.crs,
@@ -319,4 +300,5 @@ def mean(self):
319300
unit=self.unit,
320301
imp_mat=red_imp_mat,
321302
haz_type=self.haz_type,
303+
**self._reduce_attrs("mean"),
322304
)

climada/engine/test/test_impact_forecast.py

Lines changed: 28 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -167,41 +167,38 @@ def test_impact_forecast_blocked_methods(impact_forecast):
167167
impact_forecast.calc_freq_curve(np.array([10, 50, 100]))
168168

169169

170-
def test_impact_forecast_mean_min_max(impact_forecast):
170+
@pytest.fixture
171+
def impact_forecast_stats(impact_kwargs, lead_time, member):
172+
max_index = 4
173+
for key, val in impact_kwargs.items():
174+
if isinstance(val, (np.ndarray, list)):
175+
impact_kwargs[key] = val[:max_index]
176+
elif isinstance(val, csr_matrix):
177+
impact_kwargs[key] = val[:max_index, :]
178+
impact_kwargs["imp_mat"] = csr_matrix([[1, 0], [0, 1], [3, 2], [2, 3]])
179+
impact_kwargs["at_event"] = np.array([1, 1, 5, 5])
180+
return ImpactForecast(
181+
lead_time=lead_time[:max_index], member=member[:max_index], **impact_kwargs
182+
)
183+
184+
185+
@pytest.mark.parametrize("attr", ["min", "mean", "max"])
186+
def test_impact_forecast_min_mean_max(impact_forecast_stats, attr):
171187
"""Check mean, min, and max methods for ImpactForecast"""
172-
imp_fcst_mean = impact_forecast.mean()
173-
imp_fcst_min = impact_forecast.min()
174-
imp_fcst_max = impact_forecast.max()
188+
imp_fc_reduced = getattr(impact_forecast_stats, attr)()
175189

176190
# assert imp_mat
177191
npt.assert_array_equal(
178-
imp_fcst_mean.imp_mat.todense(), impact_forecast.imp_mat.todense().mean(axis=0)
192+
imp_fc_reduced.imp_mat.todense(),
193+
getattr(impact_forecast_stats.imp_mat.todense(), attr)(axis=0),
179194
)
180-
npt.assert_array_equal(imp_fcst_min.imp_mat.todense(), np.array([[0, 0]]))
181-
npt.assert_array_equal(imp_fcst_max.imp_mat.todense(), np.array([[31, 31]]))
182-
# assert at_event
183-
npt.assert_array_equal(
184-
imp_fcst_mean.at_event, impact_forecast.at_event.mean()
185-
) # 134/6
186-
npt.assert_array_equal(imp_fcst_min.at_event, impact_forecast.at_event.min())
187-
npt.assert_array_equal(imp_fcst_max.at_event, impact_forecast.at_event.max())
195+
at_event_expected = {"min": [0], "mean": [3], "max": [6]}
196+
npt.assert_array_equal(imp_fc_reduced.at_event, at_event_expected[attr])
188197

189198
# check that attributes where reduced correctly
190-
assert np.isnat(imp_fcst_mean.lead_time[0])
191-
assert np.isnat(imp_fcst_min.lead_time[0])
192-
assert np.isnat(imp_fcst_max.lead_time[0])
193-
assert imp_fcst_mean.member[0] == -1
194-
assert imp_fcst_min.member[0] == -1
195-
assert imp_fcst_max.member[0] == -1
196-
assert imp_fcst_mean.event_name[0] == "mean"
197-
assert imp_fcst_min.event_name[0] == "min"
198-
assert imp_fcst_max.event_name[0] == "max"
199-
assert imp_fcst_mean.event_id[0] == 0
200-
assert imp_fcst_min.event_id[0] == 0
201-
assert imp_fcst_max.event_id[0] == 0
202-
assert imp_fcst_mean.frequency == 1
203-
assert imp_fcst_min.frequency == 1
204-
assert imp_fcst_max.frequency == 1
205-
assert imp_fcst_mean.date == 0
206-
assert imp_fcst_min.date == 0
207-
assert imp_fcst_max.date == 0
199+
npt.assert_array_equal(np.isnat(imp_fc_reduced.lead_time), [True])
200+
npt.assert_array_equal(imp_fc_reduced.member, [-1])
201+
npt.assert_array_equal(imp_fc_reduced.event_name, [attr])
202+
npt.assert_array_equal(imp_fc_reduced.event_id, [0])
203+
npt.assert_array_equal(imp_fc_reduced.frequency, [1])
204+
npt.assert_array_equal(imp_fc_reduced.date, [0])

climada/hazard/forecast.py

Lines changed: 7 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -149,24 +149,17 @@ def min(self):
149149
HazardForecast
150150
A HazardForecast object with the min intensity and fraction.
151151
"""
152-
red_intensity = sparse.csr_matrix(self.intensity.min(axis=0))
153-
red_fraction = sparse.csr_matrix(self.fraction.min(axis=0))
154-
reduced_attrs = self._reduce_attrs("min")
152+
red_intensity = self.intensity.min(axis=0).tocsr()
153+
red_fraction = self.fraction.min(axis=0).tocsr()
155154
return HazardForecast(
156-
lead_time=reduced_attrs["lead_time"],
157-
member=reduced_attrs["member"],
158155
haz_type=self.haz_type,
159156
pool=self.pool,
160157
units=self.units,
161158
centroids=self.centroids,
162-
event_id=reduced_attrs["event_id"],
163-
frequency=reduced_attrs["frequency"],
164159
frequency_unit=self.frequency_unit,
165-
event_name=reduced_attrs["event_name"],
166-
date=reduced_attrs["date"],
167-
orig=reduced_attrs["orig"],
168160
intensity=red_intensity,
169161
fraction=red_fraction,
162+
**self._reduce_attrs("min"),
170163
)
171164

172165
def max(self):
@@ -183,24 +176,17 @@ def max(self):
183176
HazardForecast
184177
A HazardForecast object with the min intensity and fraction.
185178
"""
186-
red_intensity = sparse.csr_matrix(self.intensity.max(axis=0))
187-
red_fraction = sparse.csr_matrix(self.fraction.max(axis=0))
188-
reduced_attrs = self._reduce_attrs("max")
179+
red_intensity = self.intensity.max(axis=0).tocsr()
180+
red_fraction = self.fraction.max(axis=0).tocsr()
189181
return HazardForecast(
190-
lead_time=reduced_attrs["lead_time"],
191-
member=reduced_attrs["member"],
192182
haz_type=self.haz_type,
193183
pool=self.pool,
194184
units=self.units,
195185
centroids=self.centroids,
196-
event_id=reduced_attrs["event_id"],
197-
frequency=reduced_attrs["frequency"],
198186
frequency_unit=self.frequency_unit,
199-
event_name=reduced_attrs["event_name"],
200-
date=reduced_attrs["date"],
201-
orig=reduced_attrs["orig"],
202187
intensity=red_intensity,
203188
fraction=red_fraction,
189+
**self._reduce_attrs("max"),
204190
)
205191

206192
def mean(self):
@@ -218,20 +204,13 @@ def mean(self):
218204
"""
219205
red_intensity = sparse.csr_matrix(self.intensity.mean(axis=0))
220206
red_fraction = sparse.csr_matrix(self.fraction.mean(axis=0))
221-
reduced_attrs = self._reduce_attrs("mean")
222207
return HazardForecast(
223-
lead_time=reduced_attrs["lead_time"],
224-
member=reduced_attrs["member"],
225208
haz_type=self.haz_type,
226209
pool=self.pool,
227210
units=self.units,
228211
centroids=self.centroids,
229-
event_id=reduced_attrs["event_id"],
230-
frequency=reduced_attrs["frequency"],
231212
frequency_unit=self.frequency_unit,
232-
event_name=reduced_attrs["event_name"],
233-
date=reduced_attrs["date"],
234-
orig=reduced_attrs["orig"],
235213
intensity=red_intensity,
236214
fraction=red_fraction,
215+
**self._reduce_attrs("mean"),
237216
)

climada/hazard/test/test_forecast.py

Lines changed: 16 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -166,52 +166,26 @@ def test_write_read_hazard_forecast(haz_fc, tmp_path):
166166
npt.assert_array_equal(haz_fc.__dict__[key], haz_fc_read.__dict__[key])
167167

168168

169-
def test_hazard_forecast_mean_min_max(haz_fc):
169+
@pytest.mark.parametrize("attr", ["min", "mean", "max"])
170+
def test_hazard_forecast_mean_min_max(haz_fc, attr):
170171
"""Check mean, min, and max methods for ImpactForecast"""
171-
haz_fcst_mean = haz_fc.mean()
172-
haz_fcst_min = haz_fc.min()
173-
haz_fcst_max = haz_fc.max()
172+
haz_fcst_reduced = getattr(haz_fc, attr)()
174173

175-
# assert intensity
174+
# Assert sparse matrices
176175
npt.assert_array_equal(
177-
haz_fcst_mean.intensity.todense(), haz_fc.intensity.todense().mean(axis=0)
176+
haz_fcst_reduced.intensity.todense(),
177+
getattr(haz_fc.intensity.todense(), attr)(axis=0),
178178
)
179179
npt.assert_array_equal(
180-
haz_fcst_min.intensity.todense(), haz_fc.intensity.todense().min(axis=0)
181-
)
182-
npt.assert_array_equal(
183-
haz_fcst_max.intensity.todense(), haz_fc.intensity.todense().max(axis=0)
184-
)
185-
# assert fraction
186-
npt.assert_array_equal(
187-
haz_fcst_mean.fraction.todense(), haz_fc.fraction.todense().mean(axis=0)
188-
)
189-
npt.assert_array_equal(
190-
haz_fcst_min.fraction.todense(), haz_fc.fraction.todense().min(axis=0)
191-
)
192-
npt.assert_array_equal(
193-
haz_fcst_max.fraction.todense(), haz_fc.fraction.todense().max(axis=0)
180+
haz_fcst_reduced.fraction.todense(),
181+
getattr(haz_fc.fraction.todense(), attr)(axis=0),
194182
)
195183

196-
# check that attributes where reduced correctly
197-
assert np.isnat(haz_fcst_mean.lead_time[0])
198-
assert np.isnat(haz_fcst_min.lead_time[0])
199-
assert np.isnat(haz_fcst_max.lead_time[0])
200-
assert haz_fcst_mean.member[0] == -1
201-
assert haz_fcst_min.member[0] == -1
202-
assert haz_fcst_max.member[0] == -1
203-
assert haz_fcst_mean.event_name[0] == "mean"
204-
assert haz_fcst_min.event_name[0] == "min"
205-
assert haz_fcst_max.event_name[0] == "max"
206-
assert haz_fcst_mean.event_id[0] == 0
207-
assert haz_fcst_min.event_id[0] == 0
208-
assert haz_fcst_max.event_id[0] == 0
209-
assert haz_fcst_mean.frequency == 1
210-
assert haz_fcst_min.frequency == 1
211-
assert haz_fcst_max.frequency == 1
212-
assert haz_fcst_mean.date == 0
213-
assert haz_fcst_min.date == 0
214-
assert haz_fcst_max.date == 0
215-
assert np.all(haz_fcst_mean.orig)
216-
assert np.all(haz_fcst_min.orig)
217-
assert np.all(haz_fcst_max.orig)
184+
# Check that attributes where reduced correctly
185+
npt.assert_array_equal(np.isnat(haz_fcst_reduced.lead_time), [True])
186+
npt.assert_array_equal(haz_fcst_reduced.member, [-1])
187+
npt.assert_array_equal(haz_fcst_reduced.event_name, [attr])
188+
npt.assert_array_equal(haz_fcst_reduced.event_id, [0])
189+
npt.assert_array_equal(haz_fcst_reduced.frequency, [1])
190+
npt.assert_array_equal(haz_fcst_reduced.date, [0])
191+
npt.assert_array_equal(haz_fcst_reduced.orig, [True])

0 commit comments

Comments
 (0)