3535from tqdm import tqdm
3636
3737from ...engine .unsequa .input_var import InputVar
38- from ...entity .impact_funcs import ImpactFuncSet
38+ from ...entity .impact_funcs import ImpactFunc , ImpactFuncSet
3939from ..coordinates import country_to_iso
4040from .base import Input , Optimizer , Output
4141
@@ -46,23 +46,23 @@ def sample_data(data: pd.DataFrame, sample: list[tuple[int, int]]):
4646 """
4747 Return a DataFrame containing only the sampled values from the input data.
4848
49- The resulting data frame has the same shape and indices ad `data` and is filled with
50- NaNs, except for the row and column indices specified by `sample`.
49+ The resulting data frame has the same shape and indices ad `` data`` and is filled
50+ with NaNs, except for the row and column indices specified by `` sample` `.
5151
5252 Parameters
5353 ----------
5454 data : pandas.DataFrame
5555 The input DataFrame from which values will be sampled.
5656 sample : list of tuple of int
5757 A list of (row, column) index pairs indicating which positions
58- to copy from `data` into the returned DataFrame.
58+ to copy from `` data` ` into the returned DataFrame.
5959
6060 Returns
6161 -------
6262 pandas.DataFrame
63- A DataFrame of the same shape as `data` with NaNs in all positions
64- except those specified in `sample`, which contain the corresponding values
65- from `data`.
63+ A DataFrame of the same shape as `` data` ` with NaNs in all positions
64+ except those specified in `` sample` `, which contain the corresponding values
65+ from `` data` `.
6666 """
6767 # Create all-NaN data
6868 data_sampled = pd .DataFrame (np .nan , columns = data .columns , index = data .index )
@@ -82,7 +82,7 @@ def event_info_from_input(inp: Input) -> dict[str, Any]:
8282 Returns
8383 -------
8484 dict
85- With keys `event_id`, `region_id`, `event_name`
85+ With keys `` event_id`` , `` region_id`` , `` event_name` `
8686 """
8787 # Get region and event IDs
8888 data = inp .data .dropna (axis = "columns" , how = "all" ).dropna (axis = "index" , how = "all" )
@@ -161,6 +161,11 @@ class EnsembleOptimizerOutput:
161161 @classmethod
162162 def from_outputs (cls , outputs : Sequence [SingleEnsembleOptimizerOutput ]):
163163 """Build data from a list of outputs"""
164+ # Support empty sequences
165+ if not outputs :
166+ return cls (data = pd .DataFrame ())
167+
168+ # Derive column names
164169 cols = pd .MultiIndex .from_tuples (
165170 [("Parameters" , p_name ) for p_name in outputs [0 ].params .keys ()]
166171 + [("Event" , p_name ) for p_name in outputs [0 ].event_info ]
@@ -216,7 +221,11 @@ def to_input_var(
216221 def plot (
217222 self , impact_func_creator : Callable [..., ImpactFuncSet ], ** impf_set_plot_kwargs
218223 ):
219- """Plot all impact functions into the same plot"""
224+ """Plot all impact functions into the same plot
225+
226+ This uses the basic plot functions of
227+ :py:class:`~climada.entity.impact_funcs.base.ImpactFuncSet`.
228+ """
220229 impf_set_list = self ._to_impf_sets (impact_func_creator )
221230
222231 # Create a single plot for the overall layout, then continue plotting into it
@@ -243,9 +252,11 @@ def plot(
243252 def plot_shiny (
244253 self ,
245254 impact_func_creator : Callable [..., ImpactFuncSet ],
246- haz_type ,
247- impf_id ,
248- inp = None ,
255+ haz_type : str ,
256+ impf_id : int ,
257+ inp : Input | None = None ,
258+ impf_plot_kwargs : Mapping [str , Any ] | None = None ,
259+ hazard_plot_kwargs : Mapping [str , Any ] | None = None ,
249260 ):
250261 """Plot all impact functions with appropriate color coding and event data"""
251262 # Store data to plot
@@ -254,6 +265,12 @@ def plot_shiny(
254265 impf = impact_func_creator (** row ["Parameters" ]).get_func (
255266 haz_type = haz_type , fun_id = impf_id
256267 )
268+ if not isinstance (impf , ImpactFunc ):
269+ raise ValueError (
270+ f"Cannot find a unique impact function for haz_type: { haz_type } , "
271+ f"impf_id: { impf_id } "
272+ )
273+
257274 region_id = country_to_iso (row [("Event" , "region_id" )])
258275 event_name = row [("Event" , "event_name" )]
259276 event_id = row [("Event" , "event_id" )]
@@ -277,17 +294,24 @@ def plot_shiny(
277294 # NOTE: Actually requires selection by exposure, but this is not trivial!
278295 if inp is not None :
279296 ax2 = ax .twinx ()
280- ax2 .hist (inp .hazard .intensity .data , bins = 40 , color = "grey" , alpha = 0.5 )
297+ hist_kwargs = {"bins" : 40 , "color" : "grey" , "alpha" : 0.5 }
298+ if hazard_plot_kwargs is not None :
299+ hist_kwargs .update (hazard_plot_kwargs )
300+ ax2 .hist (inp .hazard .intensity .data , ** hist_kwargs )
301+ elif hazard_plot_kwargs is not None :
302+ LOGGER .warning ("No 'inp' parameter provided. Ignoring 'hazard_plot_kwargs'" )
281303
282304 # Sort data by final MDR value, then plot
283305 colors = plt .get_cmap ("turbo" )(np .linspace (0 , 1 , self .data .shape [0 ]))
284306 data_plt = sorted (data_plt , key = lambda x : x ["mdr" ][- 1 ], reverse = True )
307+ impf_plot_kwargs = impf_plot_kwargs if impf_plot_kwargs is not None else {}
285308 for idx , data_dict in enumerate (data_plt ):
286309 ax .plot (
287310 data_dict ["intensity" ],
288311 data_dict ["mdr" ],
289312 label = data_dict ["label" ],
290313 color = colors [idx ],
314+ ** impf_plot_kwargs ,
291315 )
292316
293317 # Cosmetics
@@ -312,40 +336,61 @@ def plot_shiny(
312336 def plot_category (
313337 self ,
314338 impact_func_creator : Callable [..., ImpactFuncSet ],
315- haz_type ,
316- impf_id ,
317- category = None ,
318- category_col_dict = None ,
339+ haz_type : str ,
340+ impf_id : int ,
341+ category : str ,
342+ category_colors : Mapping [ str , str | tuple ] | None = None ,
319343 ** impf_set_plot_kwargs ,
320344 ):
321- """Plot all impact functions with appropriate color coding according to a category"""
345+ """Plot impact functions with coloring according to a certain category
346+
347+ Parameters
348+ ----------
349+ impact_func_creator : Callable
350+ A function taking parameters and returning an
351+ :py:class:`~climada.entity.impact_funcs.base.ImpactFuncSet`.
352+ haz_type : str
353+ The hazard type of the impact function to plot.
354+ impf_id : int
355+ The ID of the impact function to plot.
356+ category : str
357+ The event information on which to categorize (can be ``"region_id"``,
358+ ``"event_id"``, or ``"event_name"``)
359+ category_colors : dict(str, str or tuple), optional
360+ Specify which categories to plot (keys) and what colors to use for them
361+ (values). If ``None``, will categorize for unique values in the ``category``
362+ column and color automatically.
363+ """
322364 impf_set_arr = np .array (self ._to_impf_sets (impact_func_creator ))
323365
324- if category_col_dict is None :
366+ if category_colors is None :
325367 unique_categories = self .data [("Event" , category )].unique ()
326- print (unique_categories )
327- unique_colors = plt .get_cmap ("Set1" )(
368+ unique_colors = plt .get_cmap ("turbo" )(
328369 np .linspace (0 , 1 , len (unique_categories ))
329370 )
330371 else :
331- unique_categories = list ( category_col_dict .keys () )
332- unique_colors = list ( category_col_dict .values () )
372+ unique_categories = category_colors .keys ()
373+ unique_colors = category_colors .values ()
333374
334- fig , ax = plt .subplots ()
375+ _ , ax = plt .subplots ()
335376 for sel_category , color in zip (unique_categories , unique_colors ):
336377 cat_idx = self .data [("Event" , category )] == sel_category
337378
338379 for i , impf_set in enumerate (impf_set_arr [cat_idx ]):
339380 impf = impf_set .get_func (haz_type = haz_type , fun_id = impf_id )
381+ if not isinstance (impf , ImpactFunc ):
382+ raise ValueError (
383+ "Cannot find a unique impact function for haz_type: "
384+ f"{ haz_type } , impf_id: { impf_id } "
385+ )
340386 label = f"{ sel_category } , n={ cat_idx .sum ()} " if i == 0 else None
341387 ax .plot (
342388 impf .intensity ,
343389 impf .paa * impf .mdd ,
344- ** impf_set_plot_kwargs ,
345390 color = color ,
346391 label = label ,
392+ ** impf_set_plot_kwargs ,
347393 )
348- # impf.mdr.plot(axis=ax, **impf_set_plot_kwargs)#, label=sel_category)
349394
350395 ax .legend (
351396 title = category , bbox_to_anchor = (1.05 , 1 ), loc = "upper left" , frameon = False
@@ -375,12 +420,7 @@ class EnsembleOptimizer(ABC):
375420 input : Input
376421 optimizer_type : type [Optimizer ]
377422 optimizer_init_kwargs : dict [str , Any ] = field (default_factory = dict )
378- samples : list [list [tuple [int , int ]]] = field (init = False )
379-
380- def __post_init__ (self , ** __ ):
381- """"""
382- if self .samples is None :
383- raise RuntimeError ("Samples must be set!" )
423+ samples : list [list [tuple [int , int ]]] = field (init = False , default_factory = list )
384424
385425 def run (self , processes = 1 , ** optimizer_run_kwargs ) -> EnsembleOptimizerOutput :
386426 """Execute the ensemble optimization
@@ -392,7 +432,7 @@ def run(self, processes=1, **optimizer_run_kwargs) -> EnsembleOptimizerOutput:
392432 1 (no parallelization)
393433 optimizer_run_kwargs
394434 Additional keywords arguments for the
395- :py:func`~climada.util.calibrate.base.Optimizer.run` method of the
435+ :py:func: `~climada.util.calibrate.base.Optimizer.run` method of the
396436 particular optimizer used.
397437 """
398438 if processes == 1 :
@@ -496,8 +536,6 @@ def __post_init__(self, sample_fraction, ensemble_size, random_state):
496536 for _ in range (ensemble_size )
497537 ]
498538
499- return super ().__post_init__ ()
500-
501539 def input_from_sample (self , sample : list [tuple [int , int ]]):
502540 """Shallow-copy the input and update the data"""
503541 input = copy (self .input ) # NOTE: Shallow copy!
@@ -512,8 +550,8 @@ class TragedyEnsembleOptimizer(EnsembleOptimizer):
512550 Attributes
513551 ----------
514552 ensemble_size : int, optional
515- The number of calibration tasks to perform. Defaults to `None`, which means one
516- for each data point. Must be smaller or equal to the number of data points.
553+ The number of calibration tasks to perform. Defaults to `` None`` , which means
554+ one for each data point. Must be smaller or equal to the number of data points.
517555 random_state : int
518556 The seed for the random number generator selecting the samples
519557 """
@@ -539,8 +577,6 @@ def __post_init__(self, ensemble_size, random_state):
539577 rng = default_rng (random_state )
540578 self .samples = rng .choice (self .samples , ensemble_size , replace = False )
541579
542- return super ().__post_init__ ()
543-
544580 def input_from_sample (self , sample : list [tuple [int , int ]]):
545581 """Subselect all input"""
546582 # Data
0 commit comments