Skip to content

Commit 2b5e4f5

Browse files
committed
Update cross-calibration module
* Add option to read and write EnsembleOptimizerOutput from/to HDF5.
* Add plotting functions for EnsembleOptimizerOutput.
* Streamline EnsembleOptimizer with generator methods.
1 parent 5505522 commit 2b5e4f5

File tree

1 file changed

+145
-36
lines changed

1 file changed

+145
-36
lines changed

climada/util/calibrate/cross_calibrate.py

Lines changed: 145 additions & 36 deletions
Original file line number | Diff line number | Diff line change
@@ -20,21 +20,25 @@
2020

2121
from abc import ABC, abstractmethod
2222
from dataclasses import dataclass, InitVar, field
23-
from typing import Optional, List, Mapping, Any, Tuple, Union, Sequence, Dict
24-
from copy import copy, deepcopy
25-
from pathlib import Path
23+
from typing import List, Any, Tuple, Sequence, Dict, Callable
24+
from copy import copy
2625
from itertools import repeat
26+
import logging
2727

2828
import numpy as np
2929
from numpy.random import default_rng
3030
import pandas as pd
3131
from pathos.multiprocessing import ProcessPool
3232
from tqdm import tqdm
33+
import matplotlib.pyplot as plt
34+
import matplotlib.ticker as mticker
3335

3436
from ...engine.unsequa.input_var import InputVar
35-
from .base import Optimizer, Output, Input
37+
from ...entity.impact_funcs import ImpactFuncSet
38+
from ..coordinates import country_to_iso
39+
from .base import Output, Input
3640

37-
# TODO: derived classes for average and tragedy
41+
LOGGER = logging.getLogger(__name__)
3842

3943

4044
def sample_data(data: pd.DataFrame, sample: List[Tuple[int, int]]):
@@ -61,8 +65,8 @@ def event_info_from_input(input: Input) -> Dict[str, Any]:
6165

6266
# Return data
6367
return {
64-
"event_id": event_ids,
65-
"region_id": region_ids,
68+
"event_id": event_ids.to_numpy(),
69+
"region_id": region_ids.to_numpy(),
6670
"event_name": event_names,
6771
}
6872

@@ -109,21 +113,134 @@ def from_outputs(cls, outputs: Sequence[SingleEnsembleOptimizerOutput]):
109113

110114
return cls(data=data)
111115

116+
def to_hdf(self, filepath):
    """Write the stored data frame to an HDF5 file under the key ``"data"``"""
    frame = self.data
    frame.to_hdf(filepath, key="data")
119+
120+
@classmethod
def from_hdf(cls, filepath):
    """Create an instance from the data frame stored under key ``"data"`` in an HDF5 file"""
    frame = pd.read_hdf(filepath, key="data")
    return cls(data=frame)
124+
112125
@classmethod
def from_csv(cls, filepath):
    """Create an instance from a CSV file with a two-row column header

    A warning is logged on every call because CSV does not preserve data types.
    """
    message = (
        "Do not use CSV for storage, because it does not preserve data types. "
        "Use HDF instead."
    )
    LOGGER.warning(message)
    frame = pd.read_csv(filepath, header=[0, 1])
    return cls(data=frame)
116133

117134
def to_csv(self, filepath):
    """Write the stored data frame to a CSV file

    A warning is logged on every call because CSV does not preserve data types.
    """
    message = (
        "Do not use CSV for storage, because it does not preserve data types. "
        "Use HDF instead."
    )
    LOGGER.warning(message)
    frame = self.data
    frame.to_csv(filepath, index=None)
120141

121-
def to_input_var(self, impact_func_creator, **impfset_kwargs):
122-
"""Build Unsequa InputVar from the parameters stored in this object"""
123-
impf_set_list = [
142+
def _to_impf_sets(self, impact_func_creator) -> List[ImpactFuncSet]:
    """Call ``impact_func_creator`` once per data row and collect the results

    The entries of the ``"Parameters"`` columns of each row are passed to the
    creator as keyword arguments. Results are returned in row order.
    """
    impf_sets = []
    for _, record in self.data.iterrows():
        params = record["Parameters"]
        impf_sets.append(impact_func_creator(**params))
    return impf_sets
126-
return InputVar.impfset(impf_set_list, **impfset_kwargs)
147+
148+
def to_input_var(
    self, impact_func_creator: Callable[..., ImpactFuncSet], **impfset_kwargs
) -> InputVar:
    """Create an Unsequa InputVar from the impact function sets of all stored rows

    ``impfset_kwargs`` are forwarded unchanged to ``InputVar.impfset``.
    """
    impf_sets = self._to_impf_sets(impact_func_creator)
    return InputVar.impfset(impf_sets, **impfset_kwargs)
155+
156+
def plot(
    self, impact_func_creator: Callable[..., ImpactFuncSet], **impf_set_plot_kwargs
):
    """Plot all impact function sets into the same plot

    Parameters
    ----------
    impact_func_creator : Callable
        Called once per data row with the entries of the ``"Parameters"``
        columns as keyword arguments. Must return an object with a ``plot``
        method that accepts an ``axis`` keyword argument.
    **impf_set_plot_kwargs
        Keyword arguments forwarded to every ``plot`` call. A caller-supplied
        ``axis`` is only used for the first call; all subsequent calls draw
        into whatever the first call returned.

    Returns
    -------
    The object returned by the first ``plot`` call (a single axes instance or
    an array of axes).

    Raises
    ------
    IndexError
        If the stored data contains no rows.
    """
    impf_set_list = self._to_impf_sets(impact_func_creator)

    # Create a single plot for the overall layout, then continue plotting into it
    axes = impf_set_list[0].plot(**impf_set_plot_kwargs)

    # 'axes' might be an array or a single instance; flatten both cases alike
    axes_flat = list(np.asarray([axes]).flat)

    # Legend is always the same, so capture it before overlaying further sets
    handles, labels = axes_flat[0].get_legend_handles_labels()

    # Override (rather than duplicate) a caller-supplied 'axis': passing it
    # both explicitly and through ** would raise a TypeError
    overlay_kwargs = {**impf_set_plot_kwargs, "axis": axes}

    # Plot remaining impact function sets
    for impf_set in impf_set_list[1:]:
        impf_set.plot(**overlay_kwargs)

    # Adjust legends
    for ax in axes_flat:
        ax.legend(handles, labels)

    return axes
182+
183+
def plot_shiny(
    self,
    impact_func_creator: Callable[..., ImpactFuncSet],
    haz_type,
    impf_id,
):
    """Plot all impact functions with appropriate color coding and event data

    One curve (``paa * mdd`` over intensity) is drawn per data row. Curves are
    sorted by their last value in descending order and colored along the
    "turbo" colormap in that order.

    Parameters
    ----------
    impact_func_creator : Callable
        Called once per data row with the entries of the ``"Parameters"``
        columns as keyword arguments. ``get_func(haz_type=..., fun_id=...)``
        is called on the returned object.
    haz_type
        Forwarded to ``get_func`` as ``haz_type`` and shown in the plot title.
    impf_id
        Forwarded to ``get_func`` as ``fun_id`` and shown in the plot title.

    Returns
    -------
    The matplotlib axes the curves were drawn into.

    Raises
    ------
    ValueError
        If the stored data contains no rows.
    """
    # Fail early and clearly instead of creating an empty figure and then
    # tripping over an undefined loop variable in the cosmetics section
    if self.data.shape[0] == 0:
        raise ValueError("Cannot plot impact functions: no data stored")

    # Store data to plot
    data_plt = []
    intensity_unit = None
    for _, row in self.data.iterrows():
        impf = impact_func_creator(**row["Parameters"]).get_func(
            haz_type=haz_type, fun_id=impf_id
        )
        # presumably maps the region IDs to ISO country codes; see the legend title
        region_id = country_to_iso(row[("Event", "region_id")])
        event_name = row[("Event", "event_name")]
        event_id = row[("Event", "event_id")]

        label = f"{event_name}, {region_id}, {event_id}"
        if len(event_name) > 1 or len(event_id) > 1:
            label = label.replace("], [", "]\n[")  # Multiline label

        data_plt.append(
            {
                "intensity": impf.intensity,
                "mdr": impf.paa * impf.mdd,
                "label": label,
            }
        )
        # Unit of the most recently created function labels the x-axis
        intensity_unit = impf.intensity_unit

    # Create plot
    _, ax = plt.subplots()
    colors = plt.get_cmap("turbo")(np.linspace(0, 1, self.data.shape[0]))

    # Sort data by final MDR value, then plot
    data_plt = sorted(data_plt, key=lambda x: x["mdr"][-1], reverse=True)
    for idx, data_dict in enumerate(data_plt):
        ax.plot(
            data_dict["intensity"],
            data_dict["mdr"],
            label=data_dict["label"],
            color=colors[idx],
        )

    # Cosmetics
    ax.set_xlabel(f"Intensity [{intensity_unit}]")
    ax.set_ylabel("Impact")
    ax.set_title(f"{haz_type} {impf_id}")
    ax.yaxis.set_major_formatter(mticker.PercentFormatter(xmax=1))
    ax.set_ylim(0, 1)
    ax.legend(
        bbox_to_anchor=(1.05, 1),
        borderaxespad=0,
        borderpad=0,
        loc="upper left",
        title="Event Name, Country, Event ID",
        frameon=False,
        fontsize="xx-small",
        title_fontsize="x-small",
    )

    return ax
127244

128245

129246
@dataclass
@@ -159,47 +276,39 @@ def run(self, processes=1, **optimizer_run_kwargs) -> EnsembleOptimizerOutput:
159276
outputs = self._iterate_parallel(processes, **optimizer_run_kwargs)
160277
return EnsembleOptimizerOutput.from_outputs(outputs)
161278

279+
def _inputs(self):
280+
"""Generator for input objects"""
281+
for sample in self.samples:
282+
yield self.input_from_sample(sample)
283+
284+
def _opt_init_kwargs(self):
285+
"""Generator for optimizer initialization keyword arguments"""
286+
for idx in range(len(self.samples)):
287+
yield self._update_init_kwargs(self.optimizer_init_kwargs, idx)
288+
162289
def _iterate_sequential(
163290
self, **optimizer_run_kwargs
164291
) -> List[SingleEnsembleOptimizerOutput]:
165292
"""Iterate over all samples sequentially"""
166-
outputs = []
167-
for idx, sample in enumerate(tqdm(self.samples)):
168-
input = self.input_from_sample(sample)
169-
170-
# Run optimizer
171-
opt = self.optimizer_type(
172-
input, **self._update_init_kwargs(self.optimizer_init_kwargs, idx)
293+
return [
294+
optimize(self.optimizer_type, input, init_kwargs, optimizer_run_kwargs)
295+
for input, init_kwargs in tqdm(
296+
zip(self._inputs(), self._opt_init_kwargs()), total=len(self.samples)
173297
)
174-
out = opt.run(**optimizer_run_kwargs)
175-
out = SingleEnsembleOptimizerOutput(
176-
params=out.params,
177-
target=out.target,
178-
event_info=event_info_from_input(input),
179-
)
180-
181-
outputs.append(out)
182-
183-
return outputs
298+
]
184299

185300
def _iterate_parallel(
186301
self, processes, **optimizer_run_kwargs
187302
) -> List[SingleEnsembleOptimizerOutput]:
188303
"""Iterate over all samples in parallel"""
189-
inputs = (self.input_from_sample(sample) for sample in self.samples)
190-
opt_init_kwargs = (
191-
self._update_init_kwargs(self.optimizer_init_kwargs, idx)
192-
for idx in range(len(self.samples))
193-
)
194-
195304
with ProcessPool(nodes=processes) as pool:
196305
return list(
197306
tqdm(
198307
pool.imap(
199308
optimize,
200309
repeat(self.optimizer_type),
201-
inputs,
202-
opt_init_kwargs,
310+
self._inputs(),
311+
self._opt_init_kwargs(),
203312
repeat(optimizer_run_kwargs),
204313
# chunksize=processes,
205314
),

0 commit comments

Comments
 (0)