Skip to content

Commit 3abb8bd

Browse files
committed
Make EnsembleOptimizer run in parallel
1 parent 149167f commit 3abb8bd

File tree

1 file changed

+71
-41
lines changed

1 file changed

+71
-41
lines changed

climada/util/calibrate/cross_calibrate.py

Lines changed: 71 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -23,10 +23,13 @@
2323
from typing import Optional, List, Mapping, Any, Tuple, Union, Sequence, Dict
2424
from copy import copy, deepcopy
2525
from pathlib import Path
26+
from itertools import repeat
2627

2728
import numpy as np
2829
from numpy.random import default_rng
2930
import pandas as pd
31+
from pathos.multiprocessing import ProcessPool
32+
from tqdm import tqdm
3033

3134
from ...engine.unsequa.input_var import InputVar
3235
from .base import Optimizer, Output, Input
@@ -46,6 +49,34 @@ def sample_data(data: pd.DataFrame, sample: List[Tuple[int, int]]):
4649
return data_sampled
4750

4851

52+
def event_info_from_input(input: Input) -> Dict[str, Any]:
    """Collect IDs and names of the event(s) covered by a calibration input

    Columns and then rows of ``input.data`` that contain only missing values
    are discarded. The remaining index is reported as event IDs and the
    remaining columns as region IDs. The event names are read from
    ``input.hazard`` after selecting these event IDs.

    Parameters
    ----------
    input : Input
        The calibration input. Only its ``data`` and ``hazard`` attributes
        are read.

    Returns
    -------
    dict
        Mapping with the keys ``"event_id"``, ``"region_id"`` and
        ``"event_name"``.
    """
    # Discard all-NaN columns first, then all-NaN rows
    non_empty = input.data.dropna(axis="columns", how="all")
    non_empty = non_empty.dropna(axis="index", how="all")

    # Whatever survives the filtering defines the events that were calibrated
    event_ids = non_empty.index
    selected_hazard = input.hazard.select(event_id=event_ids.to_list())

    return {
        "event_id": event_ids,
        "region_id": non_empty.columns,
        "event_name": selected_hazard.event_name,
    }
68+
69+
70+
def optimize(optimizer_type, input, opt_init_kwargs, opt_run_kwargs):
    """Calibrate a single ensemble member and wrap its result

    This is the task function handed to the process pool by
    ``EnsembleOptimizer._iterate_parallel``.

    Parameters
    ----------
    optimizer_type
        Callable (typically an optimizer class) invoked as
        ``optimizer_type(input, **opt_init_kwargs)``.
    input
        The calibration input passed to the optimizer.
    opt_init_kwargs : dict
        Keyword arguments used when creating the optimizer.
    opt_run_kwargs : dict
        Keyword arguments forwarded to the optimizer's ``run`` method.

    Returns
    -------
    SingleEnsembleOptimizerOutput
        The ``params`` and ``target`` of the optimizer output together with
        the event information derived from ``input``.
    """
    optimizer = optimizer_type(input, **opt_init_kwargs)
    result = optimizer.run(**opt_run_kwargs)
    return SingleEnsembleOptimizerOutput(
        params=result.params,
        target=result.target,
        event_info=event_info_from_input(input),
    )
78+
79+
4980
@dataclass
5081
class SingleEnsembleOptimizerOutput(Output):
5182
"""Output for a single member of an ensemble optimizer
@@ -94,40 +125,34 @@ def to_input_var(self, impact_func_creator, **impfset_kwargs):
94125
]
95126
return InputVar.impfset(impf_set_list, **impfset_kwargs)
96127

97-
# Build MultiIndex DataFrame
98-
# data = pd.DataFrame(
99-
# columns=pd.MultiIndex.from_tuples(
100-
# [("Parameters", p_name) for p_name in outputs[0].params.keys()]
101-
# )
102-
# )
103-
104-
# Insert Parameters
105-
# params = pd.DataFrame.from_records([out.params for out in outputs])
106-
# for p_name in params.columns:
107-
# data["Parameters", p_name] = params[p_name]
108-
109-
# Insert
110-
111-
# return cls(data=pd.DataFrame.from_records([out.params for out in outputs]))
112-
113128

114129
@dataclass
115130
class EnsembleOptimizer(ABC):
116131
""""""
117132

118133
input: Input
119134
optimizer_type: Any
120-
optimizer_init_kwargs: Mapping[str, Any] = field(default_factory=dict)
135+
optimizer_init_kwargs: Dict[str, Any] = field(default_factory=dict)
121136
samples: List[List[Tuple[int, int]]] = field(init=False)
122137

123138
def __post_init__(self):
124139
""""""
125140
if self.samples is None:
126141
raise RuntimeError("Samples must be set!")
127142

128-
def run(self, **optimizer_run_kwargs) -> EnsembleOptimizerOutput:
143+
def run(self, processes=1, **optimizer_run_kwargs) -> EnsembleOptimizerOutput:
    """Calibrate all samples and combine the results into one ensemble output

    Parameters
    ----------
    processes : int, optional
        If equal to ``1`` (default), the samples are processed one after
        another in the current process. Any other value is forwarded to
        ``_iterate_parallel``.
    optimizer_run_kwargs
        Keyword arguments passed on to the chosen iteration method.

    Returns
    -------
    EnsembleOptimizerOutput
        Built from the list of single outputs via
        ``EnsembleOptimizerOutput.from_outputs``.
    """
    run_in_parallel = processes != 1
    if run_in_parallel:
        members = self._iterate_parallel(processes, **optimizer_run_kwargs)
    else:
        members = self._iterate_sequential(**optimizer_run_kwargs)
    return EnsembleOptimizerOutput.from_outputs(members)
149+
150+
def _iterate_sequential(
151+
self, **optimizer_run_kwargs
152+
) -> List[SingleEnsembleOptimizerOutput]:
153+
"""Iterate over all samples sequentially"""
129154
outputs = []
130-
for idx, sample in enumerate(self.samples):
155+
for idx, sample in enumerate(tqdm(self.samples)):
131156
input = self.input_from_sample(sample)
132157

133158
# Run optimizer
@@ -138,13 +163,37 @@ def run(self, **optimizer_run_kwargs) -> EnsembleOptimizerOutput:
138163
out = SingleEnsembleOptimizerOutput(
139164
params=out.params,
140165
target=out.target,
141-
event_info=self.event_info_from_input(input),
166+
event_info=event_info_from_input(input),
142167
)
143168

144-
print(f"Ensemble: {idx}, Params: {out.params}")
145169
outputs.append(out)
146170

147-
return EnsembleOptimizerOutput.from_outputs(outputs)
171+
return outputs
172+
173+
def _iterate_parallel(
    self, processes, **optimizer_run_kwargs
) -> List[SingleEnsembleOptimizerOutput]:
    """Calibrate all samples with a pool of worker processes

    Parameters
    ----------
    processes
        Passed as ``nodes`` to the ``ProcessPool``.
    optimizer_run_kwargs
        Keyword arguments handed unchanged to every ``optimize`` call.

    Returns
    -------
    list of SingleEnsembleOptimizerOutput
        One entry per sample, collected from ``pool.imap``.
    """
    num_samples = len(self.samples)

    # Per-task arguments are generators, so they are evaluated lazily while
    # the pool consumes them
    sample_inputs = (self.input_from_sample(sample) for sample in self.samples)
    init_kwargs_per_task = (
        self._update_init_kwargs(self.optimizer_init_kwargs, idx)
        for idx in range(num_samples)
    )

    with ProcessPool(nodes=processes) as pool:
        results = pool.imap(
            optimize,
            repeat(self.optimizer_type),
            sample_inputs,
            init_kwargs_per_task,
            repeat(optimizer_run_kwargs),
        )
        # 'total' is given explicitly because the imap result has no length;
        # the list is built inside the 'with' block while the pool is open
        return list(tqdm(results, total=num_samples))
148197

149198
@abstractmethod
150199
def input_from_sample(self, sample: List[Tuple[int, int]]) -> Input:
@@ -159,25 +208,6 @@ def _update_init_kwargs(
159208
kwargs["random_state"] = kwargs["random_state"] + iteration
160209
return kwargs
161210

162-
def event_info_from_input(self, input: Input) -> Dict[str, Any]:
163-
"""Get information on the event(s) for which we calibrated"""
164-
# Get region and event IDs
165-
data = input.data.dropna(axis="columns", how="all").dropna(
166-
axis="index", how="all"
167-
)
168-
event_ids = data.index
169-
region_ids = data.columns
170-
171-
# Get event name
172-
event_names = input.hazard.select(event_id=event_ids.to_list()).event_name
173-
174-
# Return data
175-
return {
176-
"event_id": event_ids,
177-
"region_id": region_ids,
178-
"event_name": event_names,
179-
}
180-
181211

182212
@dataclass
183213
class AverageEnsembleOptimizer(EnsembleOptimizer):

0 commit comments

Comments
 (0)