2121from abc import ABC , abstractmethod
2222from dataclasses import dataclass , InitVar , field
2323from typing import List , Any , Tuple , Sequence , Dict , Callable
24- from copy import copy
24+ from copy import copy , deepcopy
2525from itertools import repeat
2626import logging
2727
@@ -188,6 +188,7 @@ def plot_shiny(
188188 impact_func_creator : Callable [..., ImpactFuncSet ],
189189 haz_type ,
190190 impf_id ,
191+ input = None ,
191192 ):
192193 """Plot all impact functions with appropriate color coding and event data"""
193194 # Store data to plot
@@ -214,9 +215,15 @@ def plot_shiny(
214215
215216 # Create plot
216217 _ , ax = plt .subplots ()
217- colors = plt .get_cmap ("turbo" )(np .linspace (0 , 1 , self .data .shape [0 ]))
218+
219+ # Plot hazard histogram
220+ # NOTE: Actually requires selection by exposure, but this is not trivial!
221+ if input is not None :
222+ ax2 = ax .twinx ()
223+ ax2 .hist (input .hazard .intensity .data , bins = 40 , color = "grey" , alpha = 0.5 )
218224
219225 # Sort data by final MDR value, then plot
226+ colors = plt .get_cmap ("turbo" )(np .linspace (0 , 1 , self .data .shape [0 ]))
220227 data_plt = sorted (data_plt , key = lambda x : x ["mdr" ][- 1 ], reverse = True )
221228 for idx , data_dict in enumerate (data_plt ):
222229 ax .plot (
@@ -294,7 +301,9 @@ def _iterate_sequential(
294301 ) -> List [SingleEnsembleOptimizerOutput ]:
295302 """Iterate over all samples sequentially"""
296303 return [
297- optimize (self .optimizer_type , input , init_kwargs , optimizer_run_kwargs )
304+ optimize (
305+ self .optimizer_type , input , init_kwargs , deepcopy (optimizer_run_kwargs )
306+ )
298307 for input , init_kwargs in tqdm (
299308 zip (self ._inputs (), self ._opt_init_kwargs ()), total = len (self .samples )
300309 )
@@ -304,6 +313,8 @@ def _iterate_parallel(
304313 self , processes , ** optimizer_run_kwargs
305314 ) -> List [SingleEnsembleOptimizerOutput ]:
306315 """Iterate over all samples in parallel"""
316+ iterations = len (self .samples )
317+ opt_run_kwargs = (deepcopy (optimizer_run_kwargs ) for _ in range (iterations ))
307318 with ProcessPool (nodes = processes ) as pool :
308319 return list (
309320 tqdm (
@@ -312,10 +323,10 @@ def _iterate_parallel(
312323 repeat (self .optimizer_type ),
313324 self ._inputs (),
314325 self ._opt_init_kwargs (),
315- repeat ( optimizer_run_kwargs ) ,
326+ opt_run_kwargs ,
316327 # chunksize=processes,
317328 ),
318- total = len ( self . samples ) ,
329+ total = iterations ,
319330 )
320331 )
321332
0 commit comments