Skip to content

Commit 25d89b0

Browse files
committed
Improve plots, add 'replace' parameter to AverageEnsembleOptimizer
1 parent e9cf7dd commit 25d89b0

File tree

2 files changed

+65
-19
lines changed

2 files changed

+65
-19
lines changed

climada/util/calibrate/ensemble.py

Lines changed: 41 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -257,6 +257,7 @@ def plot_shiny(
257257
inp: Input | None = None,
258258
impf_plot_kwargs: Mapping[str, Any] | None = None,
259259
hazard_plot_kwargs: Mapping[str, Any] | None = None,
260+
legend: bool = True,
260261
):
261262
"""Plot all impact functions with appropriate color coding and event data"""
262263
# Store data to plot
@@ -271,9 +272,15 @@ def plot_shiny(
271272
f"impf_id: {impf_id}"
272273
)
273274

274-
region_id = country_to_iso(row[("Event", "region_id")])
275-
event_name = row[("Event", "event_name")]
276-
event_id = row[("Event", "event_id")]
275+
def single_entry(arr):
276+
"""If ``arr`` has a single entry, return it, else ``arr`` itself"""
277+
if len(arr) == 1:
278+
return arr[0]
279+
return arr
280+
281+
region_id = single_entry(country_to_iso(row[("Event", "region_id")]))
282+
event_name = single_entry(row[("Event", "event_name")])
283+
event_id = single_entry(row[("Event", "event_id")])
277284

278285
label = f"{event_name}, {region_id}, {event_id}"
279286
if len(event_name) > 1 or len(event_id) > 1:
@@ -289,15 +296,26 @@ def plot_shiny(
289296

290297
# Create plot
291298
_, ax = plt.subplots()
299+
legend_xpad = 0
292300

293301
# Plot hazard histogram
294302
# NOTE: Actually requires selection by exposure, but this is not trivial!
295303
if inp is not None:
304+
# Create secondary axis
296305
ax2 = ax.twinx()
306+
ax2.set_zorder(1)
307+
ax.set_zorder(2)
308+
ax.set_facecolor("none")
309+
legend_xpad = 0.15
310+
311+
# Draw histogram
297312
hist_kwargs = {"bins": 40, "color": "grey", "alpha": 0.5}
298313
if hazard_plot_kwargs is not None:
299314
hist_kwargs.update(hazard_plot_kwargs)
300315
ax2.hist(inp.hazard.intensity.data, **hist_kwargs)
316+
ax2.set_ylabel("Intensity Count", color=hist_kwargs["color"])
317+
ax2.tick_params(axis="y", colors=hist_kwargs["color"])
318+
301319
elif hazard_plot_kwargs is not None:
302320
LOGGER.warning("No 'inp' parameter provided. Ignoring 'hazard_plot_kwargs'")
303321

@@ -320,16 +338,17 @@ def plot_shiny(
320338
ax.set_title(f"{haz_type} {impf_id}")
321339
ax.yaxis.set_major_formatter(mticker.PercentFormatter(xmax=1))
322340
ax.set_ylim(0, 1)
323-
ax.legend(
324-
bbox_to_anchor=(1.05, 1),
325-
borderaxespad=0,
326-
borderpad=0,
327-
loc="upper left",
328-
title="Event Name, Country, Event ID",
329-
frameon=False,
330-
fontsize="xx-small",
331-
title_fontsize="x-small",
332-
)
341+
if legend:
342+
ax.legend(
343+
bbox_to_anchor=(1.05 + legend_xpad, 1),
344+
borderaxespad=0,
345+
borderpad=0,
346+
loc="upper left",
347+
title="Event Name, Country, Event ID",
348+
frameon=False,
349+
fontsize="xx-small",
350+
title_fontsize="x-small",
351+
)
333352

334353
return ax
335354

@@ -511,16 +530,21 @@ class AverageEnsembleOptimizer(EnsembleOptimizer):
511530
The number of calibration tasks to perform
512531
random_state : int
513532
The seed for the random number generator selecting the samples
533+
replace : bool
534+
If samples of the impact data should be drawn with replacement
514535
"""
515536

516537
sample_fraction: InitVar[float] = 0.8
517538
ensemble_size: InitVar[int] = 20
518539
random_state: InitVar[int] = 1
540+
replace: InitVar[bool] = False
519541

520-
def __post_init__(self, sample_fraction, ensemble_size, random_state):
542+
def __post_init__(self, sample_fraction, ensemble_size, random_state, replace):
521543
"""Create the samples"""
522-
if sample_fraction <= 0 or sample_fraction >= 1:
523-
raise ValueError("Sample fraction must be in (0, 1)")
544+
if sample_fraction <= 0:
545+
raise ValueError("Sample fraction must be larger than 0")
546+
elif sample_fraction > 1 and not replace:
547+
raise ValueError("Sample fraction must be <=1 or replace must be True")
524548
if ensemble_size < 1:
525549
raise ValueError("Ensemble size must be >=1")
526550

@@ -532,7 +556,7 @@ def __post_init__(self, sample_fraction, ensemble_size, random_state):
532556
# Create samples
533557
rng = default_rng(random_state)
534558
self.samples = [
535-
rng.choice(notna_idx, size=num_samples, replace=False)
559+
rng.choice(notna_idx, size=num_samples, replace=replace)
536560
for _ in range(ensemble_size)
537561
]
538562

climada/util/calibrate/test/test_ensemble.py

Lines changed: 24 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -356,20 +356,42 @@ def test_post_init_sampling(self):
356356
npt.assert_array_equal(samples[0], [[0, 0], [2, 1], [3, 0], [3, 1]])
357357
npt.assert_array_equal(samples[0], samples[1])
358358

359+
def test_sampling_replace(self):
360+
"""Test if replacement works"""
361+
data = pd.DataFrame({"a": [1.0]})
362+
self.input = DummyInput(data)
363+
opt = AverageEnsembleOptimizer(
364+
input=self.input,
365+
ensemble_size=1,
366+
sample_fraction=3,
367+
replace=True,
368+
optimizer_type=ConcreteOptimizer,
369+
)
370+
npt.assert_array_equal(opt.samples, [[(0, 0), (0, 0), (0, 0)]])
371+
359372
def test_invalid_sample_fraction(self):
360373
with self.assertRaisesRegex(ValueError, "Sample fraction"):
361374
AverageEnsembleOptimizer(
362375
input=self.input,
363-
sample_fraction=1,
376+
sample_fraction=0,
364377
optimizer_type=ConcreteOptimizer,
365378
)
366379
with self.assertRaisesRegex(ValueError, "Sample fraction"):
367380
AverageEnsembleOptimizer(
368381
input=self.input,
369-
sample_fraction=0,
382+
sample_fraction=1.2,
383+
replace=False,
370384
optimizer_type=ConcreteOptimizer,
371385
)
372386

387+
# Should not throw
388+
AverageEnsembleOptimizer(
389+
input=self.input,
390+
sample_fraction=1.1,
391+
replace=True,
392+
optimizer_type=ConcreteOptimizer,
393+
)
394+
373395
def test_invalid_ensemble_size(self):
374396
with self.assertRaisesRegex(ValueError, "Ensemble size must be >=1"):
375397
AverageEnsembleOptimizer(

0 commit comments

Comments
 (0)