Skip to content

Commit 3947bef

Browse files
committed
Add option to reduce tragedy ensemble size
1 parent 951149b commit 3947bef

File tree

1 file changed

+36
-12
lines changed

1 file changed

+36
-12
lines changed

climada/util/calibrate/cross_calibrate.py

Lines changed: 36 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020

2121
from abc import ABC, abstractmethod
2222
from dataclasses import dataclass, InitVar, field
23-
from typing import List, Any, Tuple, Sequence, Dict, Callable
23+
from typing import List, Any, Tuple, Sequence, Dict, Callable, Optional
2424
from copy import copy, deepcopy
2525
from itertools import repeat
2626
import logging
@@ -260,32 +260,40 @@ def plot_category(
260260
input=None,
261261
category=None,
262262
category_col_dict=None,
263-
**impf_set_plot_kwargs
263+
**impf_set_plot_kwargs,
264264
):
265-
266265
"""Plot all impact functions with appropriate color coding according to a category"""
267266
impf_set_arr = np.array(self._to_impf_sets(impact_func_creator))
268267

269268
if category_col_dict is None:
270269
unique_categories = self.data[("Event", category)].unique()
271270
print(unique_categories)
272-
unique_colors = plt.get_cmap("Set1")(np.linspace(0, 1, len(unique_categories)))
271+
unique_colors = plt.get_cmap("Set1")(
272+
np.linspace(0, 1, len(unique_categories))
273+
)
273274
else:
274275
unique_categories = list(category_col_dict.keys())
275276
unique_colors = list(category_col_dict.values())
276277

277-
fig,ax = plt.subplots()
278-
for sel_category,color in zip(unique_categories,unique_colors):
278+
fig, ax = plt.subplots()
279+
for sel_category, color in zip(unique_categories, unique_colors):
279280
cat_idx = self.data[("Event", category)] == sel_category
280281

281-
for i,impf_set in enumerate(impf_set_arr[cat_idx]):
282+
for i, impf_set in enumerate(impf_set_arr[cat_idx]):
282283
impf = impf_set.get_func(haz_type=haz_type, fun_id=impf_id)
283-
label = f"{sel_category}, n={cat_idx.sum()} "if i == 0 else None
284-
ax.plot(impf.intensity, impf.paa * impf.mdd, **impf_set_plot_kwargs,
285-
color = color,label=label)
284+
label = f"{sel_category}, n={cat_idx.sum()} " if i == 0 else None
285+
ax.plot(
286+
impf.intensity,
287+
impf.paa * impf.mdd,
288+
**impf_set_plot_kwargs,
289+
color=color,
290+
label=label,
291+
)
286292
# impf.mdr.plot(axis=ax, **impf_set_plot_kwargs)#, label=sel_category)
287293

288-
ax.legend(title=category,bbox_to_anchor=(1.05, 1),loc='upper left',frameon=False)
294+
ax.legend(
295+
title=category, bbox_to_anchor=(1.05, 1), loc="upper left", frameon=False
296+
)
289297
# Cosmetics
290298
ax.set_xlabel(f"Intensity [{impf.intensity_unit}]")
291299
ax.set_ylabel("Impact")
@@ -423,11 +431,27 @@ def input_from_sample(self, sample: List[Tuple[int, int]]):
423431
class TragedyEnsembleOptimizer(EnsembleOptimizer):
424432
""""""
425433

426-
def __post_init__(self):
434+
ensemble_size: InitVar[Optional[int]] = None
435+
random_state: InitVar[int] = 1
436+
437+
def __post_init__(self, ensemble_size, random_state):
427438
"""Create the single samples"""
428439
notna_idx = np.argwhere(self.input.data.notna().to_numpy())
429440
self.samples = notna_idx[:, np.newaxis].tolist() # Must extend by one dimension
430441

442+
# Subselection for a given ensemble size
443+
if ensemble_size is not None:
444+
if ensemble_size < 1:
445+
raise ValueError("Ensemble size must be >=1")
446+
if ensemble_size > len(self.samples):
447+
raise ValueError(
448+
"Ensemble size must be smaller than maximum number of samples "
449+
f"(here: {len(self.samples)})"
450+
)
451+
452+
rng = default_rng(random_state)
453+
self.samples = rng.choice(self.samples, ensemble_size, replace=False)
454+
431455
return super().__post_init__()
432456

433457
def input_from_sample(self, sample: List[Tuple[int, int]]):

0 commit comments

Comments
 (0)