Skip to content

Commit 0055c64

Browse files
committed
Deepcopy optimizer run kwargs and add hazard intensity histogram to shiny plot
1 parent b5faf07 commit 0055c64

File tree

1 file changed

+16
-5
lines changed

1 file changed

+16
-5
lines changed

climada/util/calibrate/cross_calibrate.py

Lines changed: 16 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
from abc import ABC, abstractmethod
2222
from dataclasses import dataclass, InitVar, field
2323
from typing import List, Any, Tuple, Sequence, Dict, Callable
24-
from copy import copy
24+
from copy import copy, deepcopy
2525
from itertools import repeat
2626
import logging
2727

@@ -188,6 +188,7 @@ def plot_shiny(
188188
impact_func_creator: Callable[..., ImpactFuncSet],
189189
haz_type,
190190
impf_id,
191+
input=None,
191192
):
192193
"""Plot all impact functions with appropriate color coding and event data"""
193194
# Store data to plot
@@ -214,9 +215,15 @@ def plot_shiny(
214215

215216
# Create plot
216217
_, ax = plt.subplots()
217-
colors = plt.get_cmap("turbo")(np.linspace(0, 1, self.data.shape[0]))
218+
219+
# Plot hazard histogram
220+
# NOTE: Actually requires selection by exposure, but this is not trivial!
221+
if input is not None:
222+
ax2 = ax.twinx()
223+
ax2.hist(input.hazard.intensity.data, bins=40, color="grey", alpha=0.5)
218224

219225
# Sort data by final MDR value, then plot
226+
colors = plt.get_cmap("turbo")(np.linspace(0, 1, self.data.shape[0]))
220227
data_plt = sorted(data_plt, key=lambda x: x["mdr"][-1], reverse=True)
221228
for idx, data_dict in enumerate(data_plt):
222229
ax.plot(
@@ -294,7 +301,9 @@ def _iterate_sequential(
294301
) -> List[SingleEnsembleOptimizerOutput]:
295302
"""Iterate over all samples sequentially"""
296303
return [
297-
optimize(self.optimizer_type, input, init_kwargs, optimizer_run_kwargs)
304+
optimize(
305+
self.optimizer_type, input, init_kwargs, deepcopy(optimizer_run_kwargs)
306+
)
298307
for input, init_kwargs in tqdm(
299308
zip(self._inputs(), self._opt_init_kwargs()), total=len(self.samples)
300309
)
@@ -304,6 +313,8 @@ def _iterate_parallel(
304313
self, processes, **optimizer_run_kwargs
305314
) -> List[SingleEnsembleOptimizerOutput]:
306315
"""Iterate over all samples in parallel"""
316+
iterations = len(self.samples)
317+
opt_run_kwargs = (deepcopy(optimizer_run_kwargs) for _ in range(iterations))
307318
with ProcessPool(nodes=processes) as pool:
308319
return list(
309320
tqdm(
@@ -312,10 +323,10 @@ def _iterate_parallel(
312323
repeat(self.optimizer_type),
313324
self._inputs(),
314325
self._opt_init_kwargs(),
315-
repeat(optimizer_run_kwargs),
326+
opt_run_kwargs,
316327
# chunksize=processes,
317328
),
318-
total=len(self.samples),
329+
total=iterations,
319330
)
320331
)
321332

0 commit comments

Comments
 (0)