Skip to content

Commit 4ef643f

Browse files
committed
fix(plot, groups): improves plot, fixes group in eai
1 parent 304a641 commit 4ef643f

File tree

2 files changed

+30
-12
lines changed

2 files changed

+30
-12
lines changed

climada/trajectories/risk_trajectory.py

Lines changed: 28 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
import itertools
2626
import logging
2727

28+
import matplotlib.dates as mdates
2829
import matplotlib.pyplot as plt
2930
import pandas as pd
3031

@@ -344,9 +345,10 @@ def eai_metrics(self, npv: bool = True, **kwargs):
344345
This computation may become quite expensive for big areas with high resolution.
345346
346347
"""
347-
return self._compute_metrics(
348+
df = self._compute_metrics(
348349
npv=npv, metric_name="eai", metric_meth="calc_eai_gdf", **kwargs
349350
)
351+
return df
350352

351353
def aai_metrics(self, npv: bool = True, **kwargs):
352354
"""Return the average annual impacts for each date.
@@ -564,6 +566,7 @@ def plot_per_date_waterfall(
564566
ax=None,
565567
start_date: datetime.date | None = None,
566568
end_date: datetime.date | None = None,
569+
figsize=(12, 6),
567570
):
568571
"""Plot a waterfall chart of risk components over a specified date range.
569572
@@ -593,7 +596,9 @@ def plot_per_date_waterfall(
593596
compared to the risk associated with future exposure and present hazard.
594597
"""
595598
if ax is None:
596-
_, ax = plt.subplots(figsize=(12, 6))
599+
fig, ax = plt.subplots(figsize=figsize)
600+
else:
601+
fig = ax.figure # get parent figure from the axis
597602
start_date = self.start_date if start_date is None else start_date
598603
end_date = self.end_date if end_date is None else end_date
599604
risk_component = self._calc_waterfall_plot_data(
@@ -608,16 +613,30 @@ def plot_per_date_waterfall(
608613
"interaction contribution",
609614
]
610615
]
611-
risk_component.plot(ax=ax, kind="bar", stacked=True)
616+
# risk_component.plot(x="date", ax=ax, kind="bar", stacked=True)
617+
ax.stackplot(
618+
risk_component.index,
619+
[risk_component[col] for col in risk_component.columns],
620+
labels=risk_component.columns,
621+
)
622+
ax.legend()
623+
# bottom = [0] * len(risk_component)
624+
# for col in risk_component.columns:
625+
# bottom = [b + v for b, v in zip(bottom, risk_component[col])]
612626
# Construct y-axis label and title based on parameters
613627
value_label = "USD"
614-
title_label = (
615-
f"Risk between {start_date} and {end_date} (Annual Average impact)"
616-
)
628+
title_label = f"Risk between {start_date} and {end_date} (Average impact)"
629+
630+
locator = mdates.AutoDateLocator()
631+
formatter = mdates.ConciseDateFormatter(locator)
632+
633+
ax.xaxis.set_major_locator(locator)
634+
ax.xaxis.set_major_formatter(formatter)
617635

618636
ax.set_title(title_label)
619637
ax.set_ylabel(value_label)
620-
return ax
638+
ax.set_ylim(0.0, 1.1 * ax.get_ylim()[1])
639+
return fig, ax
621640

622641
def plot_waterfall(
623642
self,
@@ -719,10 +738,11 @@ def plot_waterfall(
719738

720739
# Construct y-axis label and title based on parameters
721740
value_label = "USD"
722-
title_label = f"Risk at {start_date} and {end_date} (Annual Average impact)"
741+
title_label = f"Risk at {start_date} and {end_date} (Average impact)"
723742

724743
ax.set_title(title_label)
725744
ax.set_ylabel(value_label)
745+
ax.set_ylim(0.0, 1.1 * ax.get_ylim()[1])
726746
ax.tick_params(
727747
axis="x",
728748
labelrotation=90,

climada/trajectories/riskperiod.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -685,12 +685,10 @@ def calc_eai_gdf(self):
685685
df = df.reset_index().melt(
686686
id_vars="date", var_name="coord_id", value_name="risk"
687687
)
688-
eai_gdf = self.snapshot0.exposure.gdf
688+
eai_gdf = self.snapshot0.exposure.gdf[["group_id"]]
689689
eai_gdf["coord_id"] = eai_gdf.index
690690
eai_gdf = eai_gdf.merge(df, on="coord_id")
691-
eai_gdf = eai_gdf.rename(
692-
columns={"group_id": "group", "value": "exposure_value"}
693-
)
691+
eai_gdf = eai_gdf.rename(columns={"group_id": "group"})
694692
eai_gdf["metric"] = "eai"
695693
eai_gdf["measure"] = self.measure.name if self.measure else "no_measure"
696694
return eai_gdf

0 commit comments

Comments
 (0)