2323from typing import Optional , List , Mapping , Any , Tuple , Union , Sequence , Dict
2424from copy import copy , deepcopy
2525from pathlib import Path
26+ from itertools import repeat
2627
2728import numpy as np
2829from numpy .random import default_rng
2930import pandas as pd
31+ from pathos .multiprocessing import ProcessPool
32+ from tqdm import tqdm
3033
3134from ...engine .unsequa .input_var import InputVar
3235from .base import Optimizer , Output , Input
@@ -46,6 +49,34 @@ def sample_data(data: pd.DataFrame, sample: List[Tuple[int, int]]):
4649 return data_sampled
4750
4851
def event_info_from_input(input: Input) -> Dict[str, Any]:
    """Collect the event and region identifiers present in the calibration data

    Parameters
    ----------
    input : Input
        The ``data`` attribute supplies the event (index) and region (column)
        labels; the ``hazard`` attribute supplies the event names.

    Returns
    -------
    dict
        Mapping with the keys ``"event_id"``, ``"region_id"`` and
        ``"event_name"``.
    """
    # Discard columns that hold only NaN first, then rows that hold only NaN
    without_empty_columns = input.data.dropna(axis="columns", how="all")
    remaining = without_empty_columns.dropna(axis="index", how="all")
    ids_of_events = remaining.index
    ids_of_regions = remaining.columns

    # Restrict the hazard to the remaining events to look up their names
    hazard_subset = input.hazard.select(event_id=ids_of_events.to_list())

    return {
        "event_id": ids_of_events,
        "region_id": ids_of_regions,
        "event_name": hazard_subset.event_name,
    }
68+
69+
def optimize(optimizer_type, input, opt_init_kwargs, opt_run_kwargs):
    """Instantiate one optimizer, run it, and wrap the result

    This module-level function is what ``EnsembleOptimizer._iterate_parallel``
    hands to the process pool.

    Parameters
    ----------
    optimizer_type
        Callable invoked as ``optimizer_type(input, **opt_init_kwargs)``.
    input
        Forwarded to the optimizer and to ``event_info_from_input``.
    opt_init_kwargs
        Keyword arguments for constructing the optimizer.
    opt_run_kwargs
        Keyword arguments for the optimizer's ``run`` method.

    Returns
    -------
    SingleEnsembleOptimizerOutput
        Carries the ``params`` and ``target`` of the optimizer output together
        with the event information derived from ``input``.
    """
    result = optimizer_type(input, **opt_init_kwargs).run(**opt_run_kwargs)
    return SingleEnsembleOptimizerOutput(
        params=result.params,
        target=result.target,
        event_info=event_info_from_input(input),
    )
78+
79+
4980@dataclass
5081class SingleEnsembleOptimizerOutput (Output ):
5182 """Output for a single member of an ensemble optimizer
@@ -94,40 +125,34 @@ def to_input_var(self, impact_func_creator, **impfset_kwargs):
94125 ]
95126 return InputVar .impfset (impf_set_list , ** impfset_kwargs )
96127
97- # Build MultiIndex DataFrame
98- # data = pd.DataFrame(
99- # columns=pd.MultiIndex.from_tuples(
100- # [("Parameters", p_name) for p_name in outputs[0].params.keys()]
101- # )
102- # )
103-
104- # Insert Parameters
105- # params = pd.DataFrame.from_records([out.params for out in outputs])
106- # for p_name in params.columns:
107- # data["Parameters", p_name] = params[p_name]
108-
109- # Insert
110-
111- # return cls(data=pd.DataFrame.from_records([out.params for out in outputs]))
112-
113128
114129@dataclass
115130class EnsembleOptimizer (ABC ):
116131 """"""
117132
118133 input : Input
119134 optimizer_type : Any
120- optimizer_init_kwargs : Mapping [str , Any ] = field (default_factory = dict )
135+ optimizer_init_kwargs : Dict [str , Any ] = field (default_factory = dict )
121136 samples : List [List [Tuple [int , int ]]] = field (init = False )
122137
def __post_init__(self):
    """Verify that the list of samples has been populated

    ``samples`` is declared with ``field(init=False)`` and no default, so the
    dataclass-generated ``__init__`` never assigns it. Presumably subclasses
    set it before delegating to this method (to be confirmed against the
    subclasses).

    Raises
    ------
    RuntimeError
        If ``samples`` was never assigned or is ``None``.
    """
    # A plain ``self.samples`` raises AttributeError when the attribute was
    # never assigned, which would mask the intended error message below.
    if getattr(self, "samples", None) is None:
        raise RuntimeError("Samples must be set!")
127142
128- def run (self , ** optimizer_run_kwargs ) -> EnsembleOptimizerOutput :
def run(self, processes=1, **optimizer_run_kwargs) -> EnsembleOptimizerOutput:
    """Execute the optimizer for all samples and combine the results

    Parameters
    ----------
    processes
        When equal to ``1`` (the default) the samples are handled by
        ``_iterate_sequential``; any other value is passed on to
        ``_iterate_parallel``.
    **optimizer_run_kwargs
        Forwarded unchanged to the selected iteration method.

    Returns
    -------
    EnsembleOptimizerOutput
        Built from the per-sample outputs via
        ``EnsembleOptimizerOutput.from_outputs``.
    """
    if processes == 1:
        member_outputs = self._iterate_sequential(**optimizer_run_kwargs)
        return EnsembleOptimizerOutput.from_outputs(member_outputs)

    member_outputs = self._iterate_parallel(processes, **optimizer_run_kwargs)
    return EnsembleOptimizerOutput.from_outputs(member_outputs)
149+
150+ def _iterate_sequential (
151+ self , ** optimizer_run_kwargs
152+ ) -> List [SingleEnsembleOptimizerOutput ]:
153+ """Iterate over all samples sequentially"""
129154 outputs = []
130- for idx , sample in enumerate (self .samples ):
155+ for idx , sample in enumerate (tqdm ( self .samples ) ):
131156 input = self .input_from_sample (sample )
132157
133158 # Run optimizer
@@ -138,13 +163,37 @@ def run(self, **optimizer_run_kwargs) -> EnsembleOptimizerOutput:
138163 out = SingleEnsembleOptimizerOutput (
139164 params = out .params ,
140165 target = out .target ,
141- event_info = self . event_info_from_input (input ),
166+ event_info = event_info_from_input (input ),
142167 )
143168
144- print (f"Ensemble: { idx } , Params: { out .params } " )
145169 outputs .append (out )
146170
147- return EnsembleOptimizerOutput .from_outputs (outputs )
171+ return outputs
172+
def _iterate_parallel(
    self, processes, **optimizer_run_kwargs
) -> List[SingleEnsembleOptimizerOutput]:
    """Run the optimizer for every sample through a pool of worker processes

    Parameters
    ----------
    processes
        Passed as ``nodes`` to ``ProcessPool``.
    **optimizer_run_kwargs
        Handed to every ``optimize`` call as its ``opt_run_kwargs`` argument.

    Returns
    -------
    list
        The results of ``optimize`` for all samples, collected from
        ``pool.imap``.
    """
    num_samples = len(self.samples)

    # Per-sample arguments are produced lazily while the pool consumes them
    sample_inputs = map(self.input_from_sample, self.samples)
    init_kwargs_per_sample = (
        self._update_init_kwargs(self.optimizer_init_kwargs, iteration)
        for iteration in range(num_samples)
    )

    with ProcessPool(nodes=processes) as pool:
        results = pool.imap(
            optimize,
            repeat(self.optimizer_type),
            sample_inputs,
            init_kwargs_per_sample,
            repeat(optimizer_run_kwargs),
        )
        # Wrap the result iterator so that progress is shown while collecting
        progress = tqdm(results, total=num_samples)
        return list(progress)
148197
149198 @abstractmethod
150199 def input_from_sample (self , sample : List [Tuple [int , int ]]) -> Input :
@@ -159,25 +208,6 @@ def _update_init_kwargs(
159208 kwargs ["random_state" ] = kwargs ["random_state" ] + iteration
160209 return kwargs
161210
162- def event_info_from_input (self , input : Input ) -> Dict [str , Any ]:
163- """Get information on the event(s) for which we calibrated"""
164- # Get region and event IDs
165- data = input .data .dropna (axis = "columns" , how = "all" ).dropna (
166- axis = "index" , how = "all"
167- )
168- event_ids = data .index
169- region_ids = data .columns
170-
171- # Get event name
172- event_names = input .hazard .select (event_id = event_ids .to_list ()).event_name
173-
174- # Return data
175- return {
176- "event_id" : event_ids ,
177- "region_id" : region_ids ,
178- "event_name" : event_names ,
179- }
180-
181211
182212@dataclass
183213class AverageEnsembleOptimizer (EnsembleOptimizer ):
0 commit comments