|
20 | 20 |
|
21 | 21 | from abc import ABC, abstractmethod |
22 | 22 | from dataclasses import dataclass, InitVar, field |
23 | | -from typing import List, Any, Tuple, Sequence, Dict, Callable |
| 23 | +from typing import List, Any, Tuple, Sequence, Dict, Callable, Optional |
24 | 24 | from copy import copy, deepcopy |
25 | 25 | from itertools import repeat |
26 | 26 | import logging |
@@ -260,32 +260,40 @@ def plot_category( |
260 | 260 | input=None, |
261 | 261 | category=None, |
262 | 262 | category_col_dict=None, |
263 | | - **impf_set_plot_kwargs |
| 263 | + **impf_set_plot_kwargs, |
264 | 264 | ): |
265 | | - |
266 | 265 | """Plot all impact functions with appropriate color coding according to a category""" |
267 | 266 | impf_set_arr = np.array(self._to_impf_sets(impact_func_creator)) |
268 | 267 |
|
269 | 268 | if category_col_dict is None: |
270 | 269 | unique_categories = self.data[("Event", category)].unique() |
271 | 270 | print(unique_categories) |
272 | | - unique_colors = plt.get_cmap("Set1")(np.linspace(0, 1, len(unique_categories))) |
| 271 | + unique_colors = plt.get_cmap("Set1")( |
| 272 | + np.linspace(0, 1, len(unique_categories)) |
| 273 | + ) |
273 | 274 | else: |
274 | 275 | unique_categories = list(category_col_dict.keys()) |
275 | 276 | unique_colors = list(category_col_dict.values()) |
276 | 277 |
|
277 | | - fig,ax = plt.subplots() |
278 | | - for sel_category,color in zip(unique_categories,unique_colors): |
| 278 | + fig, ax = plt.subplots() |
| 279 | + for sel_category, color in zip(unique_categories, unique_colors): |
279 | 280 | cat_idx = self.data[("Event", category)] == sel_category |
280 | 281 |
|
281 | | - for i,impf_set in enumerate(impf_set_arr[cat_idx]): |
| 282 | + for i, impf_set in enumerate(impf_set_arr[cat_idx]): |
282 | 283 | impf = impf_set.get_func(haz_type=haz_type, fun_id=impf_id) |
283 | | - label = f"{sel_category}, n={cat_idx.sum()} "if i == 0 else None |
284 | | - ax.plot(impf.intensity, impf.paa * impf.mdd, **impf_set_plot_kwargs, |
285 | | - color = color,label=label) |
| 284 | + label = f"{sel_category}, n={cat_idx.sum()} " if i == 0 else None |
| 285 | + ax.plot( |
| 286 | + impf.intensity, |
| 287 | + impf.paa * impf.mdd, |
| 288 | + **impf_set_plot_kwargs, |
| 289 | + color=color, |
| 290 | + label=label, |
| 291 | + ) |
286 | 292 | # impf.mdr.plot(axis=ax, **impf_set_plot_kwargs)#, label=sel_category) |
287 | 293 |
|
288 | | - ax.legend(title=category,bbox_to_anchor=(1.05, 1),loc='upper left',frameon=False) |
| 294 | + ax.legend( |
| 295 | + title=category, bbox_to_anchor=(1.05, 1), loc="upper left", frameon=False |
| 296 | + ) |
289 | 297 | # Cosmetics |
290 | 298 | ax.set_xlabel(f"Intensity [{impf.intensity_unit}]") |
291 | 299 | ax.set_ylabel("Impact") |
@@ -423,11 +431,27 @@ def input_from_sample(self, sample: List[Tuple[int, int]]): |
423 | 431 | class TragedyEnsembleOptimizer(EnsembleOptimizer): |
424 | 432 | """""" |
425 | 433 |
|
426 | | - def __post_init__(self): |
| 434 | + ensemble_size: InitVar[Optional[int]] = None |
| 435 | + random_state: InitVar[int] = 1 |
| 436 | + |
| 437 | + def __post_init__(self, ensemble_size, random_state): |
427 | 438 | """Create the single samples""" |
428 | 439 | notna_idx = np.argwhere(self.input.data.notna().to_numpy()) |
429 | 440 | self.samples = notna_idx[:, np.newaxis].tolist() # Must extend by one dimension |
430 | 441 |
|
| 442 | + # Subselection for a given ensemble size |
| 443 | + if ensemble_size is not None: |
| 444 | + if ensemble_size < 1: |
| 445 | + raise ValueError("Ensemble size must be >=1") |
| 446 | + if ensemble_size > len(self.samples): |
| 447 | + raise ValueError( |
| 448 | + "Ensemble size must be smaller than maximum number of samples " |
| 449 | + f"(here: {len(self.samples)})" |
| 450 | + ) |
| 451 | + |
| 452 | + rng = default_rng(random_state) |
| 453 | + self.samples = rng.choice(self.samples, ensemble_size, replace=False) |
| 454 | + |
431 | 455 | return super().__post_init__() |
432 | 456 |
|
433 | 457 | def input_from_sample(self, sample: List[Tuple[int, int]]): |
|
0 commit comments