Skip to content

Commit 3f0dfc6

Browse files
committed
feat(interpolation): Vastly improves interpolation of metrics
1 parent d2b0f6b commit 3f0dfc6

File tree

3 files changed

+454
-211
lines changed

3 files changed

+454
-211
lines changed

climada/trajectories/interpolation.py

Lines changed: 124 additions & 96 deletions
Original file line numberDiff line numberDiff line change
@@ -21,121 +21,149 @@
2121
"""
2222

2323
import logging
24-
from abc import ABC, abstractmethod
24+
from abc import ABC
25+
from typing import Callable
2526

2627
import numpy as np
27-
from scipy.sparse import csr_matrix, lil_matrix
2828

2929
LOGGER = logging.getLogger(__name__)
3030

3131

32-
class InterpolationStrategy(ABC):
33-
"""Interface for interpolation strategies."""
34-
35-
@abstractmethod
36-
def interpolate(self, imp_E0, imp_E1, time_points: int) -> list: ...
37-
38-
39-
class LinearInterpolation(InterpolationStrategy):
40-
"""Linear interpolation strategy."""
41-
42-
def interpolate(self, imp_E0, imp_E1, time_points: int):
32+
def linear_interp_imp_mat(mat_start, mat_end, interpolation_range) -> list:
33+
"""Linearly interpolates between two impact matrices over an interpolation range.
34+
35+
Returns a list of `interpolation_range` matrices linearly interpolated between
36+
`mat_start` and `mat_end`.
37+
"""
38+
res = []
39+
for point in range(interpolation_range):
40+
ratio = point / (interpolation_range - 1)
41+
mat_interpolated = mat_start + ratio * (mat_end - mat_start)
42+
res.append(mat_interpolated)
43+
return res
44+
45+
46+
def exponential_interp_imp_mat(mat_start, mat_end, interpolation_range, rate) -> list:
47+
"""Exponentially interpolates between two impact matrices over an interpolation range with a growth rate `rate`.
48+
49+
Returns a list of `interpolation_range` matrices exponentially (with growth rate `rate`) interpolated between
50+
`mat_start` and `mat_end`.
51+
"""
52+
# Convert matrices to logarithmic domain
53+
mat_start = mat_start.copy()
54+
mat_end = mat_end.copy()
55+
mat_start.data = np.log(mat_start.data + np.finfo(float).eps) / np.log(rate)
56+
mat_end.data = np.log(mat_end.data + np.finfo(float).eps) / np.log(rate)
57+
58+
# Perform linear interpolation in the logarithmic domain
59+
res = []
60+
for point in range(interpolation_range):
61+
ratio = point / (interpolation_range - 1)
62+
mat_interpolated = mat_start * (1 - ratio) + ratio * mat_end
63+
mat_interpolated.data = np.exp(mat_interpolated.data * np.log(rate))
64+
res.append(mat_interpolated)
65+
return res
66+
67+
68+
def linear_interp_arrays(arr_start, arr_end, interpolation_range):
69+
"""Perform linear interpolation between two arrays (of a scalar metric) over an interpolation range."""
70+
prop1 = np.linspace(0, 1, interpolation_range)
71+
prop0 = 1 - prop1
72+
if arr_start.ndim > 1:
73+
prop0, prop1 = prop0.reshape(-1, 1), prop1.reshape(-1, 1)
74+
75+
return np.multiply(arr_start, prop0) + np.multiply(arr_end, prop1)
76+
77+
78+
def exponential_interp_arrays(arr_start, arr_end, interpolation_range, rate):
79+
"""Perform exponential interpolation between two arrays (of a scalar metric) over an interpolation range with a growth rate `rate`."""
80+
prop1 = np.linspace(0, 1, interpolation_range)
81+
prop0 = 1 - prop1
82+
if arr_start.ndim > 1:
83+
prop0, prop1 = prop0.reshape(-1, 1), prop1.reshape(-1, 1)
84+
85+
return np.exp(
86+
(
87+
np.multiply(np.log(arr_start) / np.log(rate), prop0)
88+
+ np.multiply(np.log(arr_end) / np.log(rate), prop1)
89+
)
90+
* np.log(rate)
91+
)
92+
93+
94+
def logarithmic_interp_arrays(arr_start, arr_end, interpolation_range):
95+
"""Perform logarithmic (natural logarithm) interpolation between two arrays (of a scalar metric) over an interpolation range."""
96+
prop1 = np.logspace(0, 1, interpolation_range)
97+
prop0 = 1 - prop1
98+
if arr_start.ndim > 1:
99+
prop0, prop1 = prop0.reshape(-1, 1), prop1.reshape(-1, 1)
100+
101+
return np.multiply(arr_start, prop0) + np.multiply(arr_end, prop1)
102+
103+
104+
class InterpolationStrategyBase(ABC):
105+
exposure_interp: Callable
106+
hazard_interp: Callable
107+
vulnerability_interp: Callable
108+
109+
def interp_exposure_dim(
110+
self, imp_E0, imp_E1, interpolation_range: int, **kwargs
111+
) -> list:
112+
"""Interpolates along the exposure change between two impact matrices.
113+
114+
Returns a list of `interpolation_range` matrices linearly interpolated between
115+
`mat_start` and `mat_end`.
116+
"""
43117
try:
44-
return self.interpolate_imp_mat(imp_E0, imp_E1, time_points)
45-
except ValueError as e:
46-
if str(e) == "inconsistent shapes":
118+
res = self.exposure_interp(imp_E0, imp_E1, interpolation_range, **kwargs)
119+
except ValueError as err:
120+
if str(err) == "inconsistent shapes":
47121
raise ValueError(
48-
"Interpolation between impact matrices of different shapes"
122+
"Tried to interpolate impact matrices of different shape. A possible reason could be Exposures of different shapes."
49123
)
50-
else:
51-
raise e
52-
53-
@staticmethod
54-
def interpolate_imp_mat(imp0, imp1, time_points):
55-
"""Interpolate between two impact matrices over a specified time range.
56-
57-
Parameters
58-
----------
59-
imp0 : ImpactCalc
60-
The impact calculation for the starting time.
61-
imp1 : ImpactCalc
62-
The impact calculation for the ending time.
63-
time_points:
64-
The number of points to interpolate.
65-
66-
Returns
67-
-------
68-
list of np.ndarray
69-
List of interpolated impact matrices for each time points in the specified range.
70-
"""
71124

72-
def interpolate_sm(mat_start, mat_end, time, time_points):
73-
"""Perform linear interpolation between two matrices for a specified time point."""
74-
if time > time_points:
75-
raise ValueError("time point must be within the range")
125+
raise err
76126

77-
ratio = time / (time_points - 1)
127+
return res
78128

79-
# Convert the input matrices to a format that allows efficient modification of its elements
80-
mat_start = lil_matrix(mat_start)
81-
mat_end = lil_matrix(mat_end)
129+
def interp_hazard_dim(
130+
self, metric_0, metric_1, interpolation_range: int, **kwargs
131+
) -> np.ndarray:
132+
return self.hazard_interp(metric_0, metric_1, interpolation_range, **kwargs)
82133

83-
# Perform the linear interpolation
84-
mat_interpolated = mat_start + ratio * (mat_end - mat_start)
134+
def interp_vulnerability_dim(
135+
self, metric_0, metric_1, interpolation_range: int, **kwargs
136+
) -> np.ndarray:
137+
return self.vulnerability_interp(
138+
metric_0, metric_1, interpolation_range, **kwargs
139+
)
85140

86-
return csr_matrix(mat_interpolated)
87-
88-
LOGGER.debug(f"imp0: {imp0.imp_mat.data[0]}, imp1: {imp1.imp_mat.data[0]}")
89-
return [
90-
interpolate_sm(imp0.imp_mat, imp1.imp_mat, time, time_points)
91-
for time in range(time_points)
92-
]
93141

142+
class InterpolationStrategy(InterpolationStrategyBase):
143+
"""Interface for interpolation strategies."""
94144

95-
class ExponentialInterpolation(InterpolationStrategy):
96-
"""Exponential interpolation strategy."""
145+
def __init__(self, exposure_interp, hazard_interp, vulnerability_interp) -> None:
146+
super().__init__()
147+
self.exposure_interp = exposure_interp
148+
self.hazard_interp = hazard_interp
149+
self.vulnerability_interp = vulnerability_interp
97150

98-
def interpolate(self, imp_E0, imp_E1, time_points: int):
99-
return self.interpolate_imp_mat(imp_E0, imp_E1, time_points)
100-
101-
@staticmethod
102-
def interpolate_imp_mat(imp0, imp1, time_points):
103-
"""Interpolate between two impact matrices over a specified time range.
104-
105-
Parameters
106-
----------
107-
imp0 : ImpactCalc
108-
The impact calculation for the starting time.
109-
imp1 : ImpactCalc
110-
The impact calculation for the ending time.
111-
time_points:
112-
The number of points to interpolate.
113-
114-
Returns
115-
-------
116-
list of np.ndarray
117-
List of interpolated impact matrices for each time points in the specified range.
118-
"""
119151

120-
def interpolate_sm(mat_start, mat_end, time, time_points):
121-
"""Perform exponential interpolation between two matrices for a specified time point."""
122-
if time > time_points:
123-
raise ValueError("time point must be within the range")
124-
125-
# Convert matrices to logarithmic domain
126-
log_mat_start = np.log(mat_start.toarray() + np.finfo(float).eps)
127-
log_mat_end = np.log(mat_end.toarray() + np.finfo(float).eps)
152+
class AllLinearStrategy(InterpolationStrategyBase):
153+
"""Linear interpolation strategy."""
128154

129-
# Perform linear interpolation in the logarithmic domain
130-
ratio = time / (time_points - 1)
131-
log_mat_interpolated = log_mat_start + ratio * (log_mat_end - log_mat_start)
155+
def __init__(self) -> None:
156+
super().__init__()
157+
self.exposure_interp = linear_interp_imp_mat
158+
self.hazard_interp = linear_interp_arrays
159+
self.vulnerability_interp = linear_interp_arrays
132160

133-
# Convert back to the original domain using the exponential function
134-
mat_interpolated = np.exp(log_mat_interpolated)
135161

136-
return csr_matrix(mat_interpolated)
162+
class ExponentialExposureInterpolation(InterpolationStrategyBase):
163+
"""Exponential interpolation strategy."""
137164

138-
return [
139-
interpolate_sm(imp0.imp_mat, imp1.imp_mat, time, time_points)
140-
for time in range(time_points)
141-
]
165+
def __init__(self) -> None:
166+
super().__init__()
167+
self.exposure_interp = exponential_interp_imp_mat
168+
self.hazard_interp = linear_interp_arrays
169+
self.vulnerability_interp = linear_interp_arrays

climada/trajectories/risk_trajectory.py

Lines changed: 52 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -29,12 +29,12 @@
2929
import pandas as pd
3030

3131
from climada.entity.disc_rates.base import DiscRates
32+
from climada.trajectories.interpolation import InterpolationStrategyBase
3233
from climada.trajectories.riskperiod import (
34+
AllLinearStrategy,
3335
CalcRiskPeriod,
3436
ImpactCalcComputation,
3537
ImpactComputationStrategy,
36-
InterpolationStrategy,
37-
LinearInterpolation,
3838
)
3939
from climada.trajectories.snapshot import Snapshot
4040

@@ -79,7 +79,7 @@ def __init__(
7979
risk_transf_cover=None,
8080
risk_transf_attach=None,
8181
calc_residual: bool = True,
82-
interpolation_strategy: InterpolationStrategy | None = None,
82+
interpolation_strategy: InterpolationStrategyBase | None = None,
8383
impact_computation_strategy: ImpactComputationStrategy | None = None,
8484
):
8585
self._reset_metrics()
@@ -94,7 +94,7 @@ def __init__(
9494
self._risk_transf_cover = risk_transf_cover
9595
self._risk_transf_attach = risk_transf_attach
9696
self._calc_residual = calc_residual
97-
self._interpolation_strategy = interpolation_strategy or LinearInterpolation()
97+
self._interpolation_strategy = interpolation_strategy or AllLinearStrategy()
9898
self._impact_computation_strategy = (
9999
impact_computation_strategy or ImpactCalcComputation()
100100
)
@@ -283,6 +283,9 @@ def _generic_metrics(
283283
raise e
284284
else:
285285
tmp = tmp.set_index(["date", "group", "measure", "metric"])
286+
if "coord_id" in tmp.columns:
287+
tmp = tmp.set_index(["coord_id"], append=True)
288+
286289
tmp = tmp[
287290
~tmp.index.duplicated(keep="last")
288291
] # We want to avoid overlap when more than 2 snapshots
@@ -596,6 +599,15 @@ def plot_per_date_waterfall(
596599
risk_component = self._calc_waterfall_plot_data(
597600
start_date=start_date, end_date=end_date
598601
)
602+
risk_component = risk_component[
603+
[
604+
"base risk",
605+
"exposure contribution",
606+
"hazard contribution",
607+
"vulnerability contribution",
608+
"interaction contribution",
609+
]
610+
]
599611
risk_component.plot(ax=ax, kind="bar", stacked=True)
600612
# Construct y-axis label and title based on parameters
601613
value_label = "USD"
@@ -654,21 +666,30 @@ def plot_waterfall(
654666
labels = [
655667
f"Risk {start_date}",
656668
f"Exposure {end_date}",
657-
f"Hazard {end_date}¹",
669+
f"Hazard {end_date}",
670+
f"Vulnerability {end_date}",
671+
f"Interaction {end_date}",
658672
f"Total Risk {end_date}",
659673
]
660674
values = [
661675
risk_component["base risk"],
662-
risk_component["delta from exposure"],
663-
risk_component["delta from hazard"],
664-
risk_component["base risk"]
665-
+ risk_component["delta from exposure"]
666-
+ risk_component["delta from hazard"],
676+
risk_component["exposure contribution"],
677+
risk_component["hazard contribution"],
678+
risk_component["vulnerability contribution"],
679+
risk_component["interaction contribution"],
680+
risk_component.sum(),
667681
]
668682
bottoms = [
669683
0.0,
670684
risk_component["base risk"],
671-
risk_component["base risk"] + risk_component["delta from exposure"],
685+
risk_component["base risk"] + risk_component["exposure contribution"],
686+
risk_component["base risk"]
687+
+ risk_component["exposure contribution"]
688+
+ risk_component["hazard contribution"],
689+
risk_component["base risk"]
690+
+ risk_component["exposure contribution"]
691+
+ risk_component["hazard contribution"]
692+
+ risk_component["vulnerability contribution"],
672693
0.0,
673694
]
674695

@@ -677,7 +698,14 @@ def plot_waterfall(
677698
values,
678699
bottom=bottoms,
679700
edgecolor="black",
680-
color=["tab:blue", "tab:orange", "tab:green", "tab:red"],
701+
color=[
702+
"tab:cyan",
703+
"tab:orange",
704+
"tab:green",
705+
"tab:red",
706+
"tab:purple",
707+
"tab:blue",
708+
],
681709
)
682710
for i in range(len(values)):
683711
ax.text(
@@ -695,16 +723,19 @@ def plot_waterfall(
695723

696724
ax.set_title(title_label)
697725
ax.set_ylabel(value_label)
698-
# ax.tick_params(axis='x', labelrotation=90,)
699-
ax.annotate(
700-
"""¹: The increase in risk due to hazard denotes the difference in risk with future exposure
701-
and hazard compared to risk with future exposure and present hazard.""",
702-
xy=(0.0, -0.15),
703-
xycoords="axes fraction",
704-
ha="left",
705-
va="center",
706-
fontsize=8,
726+
ax.tick_params(
727+
axis="x",
728+
labelrotation=90,
707729
)
730+
# ax.annotate(
731+
# """¹: The increase in risk due to hazard denotes the difference in risk with future exposure
732+
# and hazard compared to risk with future exposure and present hazard.""",
733+
# xy=(0.0, -0.15),
734+
# xycoords="axes fraction",
735+
# ha="left",
736+
# va="center",
737+
# fontsize=8,
738+
# )
708739

709740
return ax
710741

0 commit comments

Comments
 (0)