Skip to content

Commit 0cfe35c

Browse files
update interpolate util module to add constant extrapolation option
1 parent 5d11f75 commit 0cfe35c

File tree

2 files changed

+57
-26
lines changed

2 files changed

+57
-26
lines changed

climada/util/interpolation.py

Lines changed: 30 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ def interpolate_ev(
3737
logy = False,
3838
x_threshold = None,
3939
y_threshold = None,
40-
extrapolation = False,
40+
extrapolation = None,
4141
y_asymptotic = np.nan
4242
):
4343
"""
@@ -60,14 +60,16 @@ def interpolate_ev(
6060
Lower threshold to filter x_train. Defaults to None.
6161
y_threshold : float, optional
6262
Lower threshold to filter y_train. Defaults to None.
63-
extrapolation : bool, optional
64-
If set to True, values will be extrapolated. If set to False, x_test values
65-
smaller than x_train will be assigned y_train[0] (x_train must be sorted in
66-
ascending order), and x_test values larger than x_train will be assigned
67-
y_asymptotic. Defaults to False
63+
extrapolation : str, optional
64+
If set to 'extrapolate', values will be extrapolated. If set to 'extrapolate_constant',
65+
x_test values smaller than x_train will be assigned y_train[0] (x_train must be sorted
66+
in ascending order), and x_test values larger than x_train will be assigned
67+
y_asymptotic. If set to None, x_test values outside of the range of x_train will be
68+
assigned np.nan. Defaults to None.
6869
y_asymptotic : float, optional
69-
Return value and if extrapolation is True or x_train.size < 2, for x_test
70-
values larger than x_train. Defaults to np.nan.
70+
Has no effect if extrapolation is None. Else, provides return value and if
71+
for x_test values larger than x_train, for x_train.size < 2 or if extrapolation is set
72+
to 'extrapolate_constant'. Defaults to np.nan.
7173
7274
Returns
7375
-------
@@ -80,20 +82,27 @@ def interpolate_ev(
8082
x_test, x_train, y_train, logx, logy, x_threshold, y_threshold
8183
)
8284

83-
# handle case of small training data sizes
85+
# handle case of small training data sizes
8486
if x_train.size < 2:
87+
if not extrapolation:
88+
return np.full_like(x_test, np.nan)
89+
else:
90+
LOGGER.warning('Data is being extrapolated.')
91+
return _interpolate_small_input(x_test, x_train, y_train, logy, y_asymptotic)
92+
93+
# warn if values are being extrapolated
94+
if extrapolation and (np.min(x_test) < np.min(x_train) or np.max(x_test) > np.max(x_train)):
8595
LOGGER.warning('Data is being extrapolated.')
86-
return _interpolate_small_input(x_test, x_train, y_train, logy, y_asymptotic)
8796

8897
# calculate fill values
89-
if extrapolation:
98+
if extrapolation == 'extrapolate':
9099
fill_value = 'extrapolate'
91-
if np.min(x_test) < np.min(x_train) or np.max(x_test) > np.max(x_train):
92-
LOGGER.warning('Data is being extrapolated.')
93-
else:
100+
elif extrapolation == 'extrapolate_constant':
94101
if not all(sorted(x_train) == x_train):
95102
raise ValueError('x_train array must be sorted in ascending order.')
96103
fill_value = (y_train[0], np.log10(y_asymptotic) if logy else y_asymptotic)
104+
else:
105+
fill_value = np.nan
97106

98107
interpolation = interpolate.interp1d(
99108
x_train, y_train, fill_value=fill_value, bounds_error=False)
@@ -142,9 +151,16 @@ def stepfunction_ev(
142151
x_test, x_train, y_train, None, None, x_threshold, y_threshold
143152
)
144153

154+
155+
145156
# handle case of small training data sizes
146157
if x_train.size < 2:
158+
LOGGER.warning('Data is being extrapolated.')
147159
return _interpolate_small_input(x_test, x_train, y_train, None, y_asymptotic)
160+
161+
# warn if values are being extrapolated
162+
if (np.min(x_test) < np.min(x_train) or np.max(x_test) > np.max(x_train)):
163+
LOGGER.warning('Data is being extrapolated.')
148164

149165
# find indices of x_test if sorted into x_train
150166
if not all(sorted(x_train) == x_train):

climada/util/test/test_interpolation.py

Lines changed: 27 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -35,10 +35,15 @@ def test_interpolate_ev_linear_interp(self):
3535
x_test = np.array([0., 3., 4., 6.])
3636
np.testing.assert_allclose(
3737
interpolate_ev(x_test, x_train, y_train),
38+
np.array([np.nan, 4., 3., np.nan])
39+
)
40+
np.testing.assert_allclose(
41+
interpolate_ev(x_test, x_train, y_train, extrapolation='extrapolate_constant'),
3842
np.array([8., 4., 3., np.nan])
3943
)
4044
np.testing.assert_allclose(
41-
interpolate_ev(x_test, x_train, y_train, y_asymptotic = 0),
45+
interpolate_ev(x_test, x_train, y_train,
46+
extrapolation='extrapolate_constant', y_asymptotic = 0),
4247
np.array([8., 4., 3., 0.])
4348
)
4449

@@ -48,15 +53,17 @@ def test_interpolate_ev_threshold_parameters(self):
4853
y_train = np.array([4., 1., 4.])
4954
x_test = np.array([-1., 3., 4.])
5055
np.testing.assert_allclose(
51-
interpolate_ev(x_test, x_train, y_train),
56+
interpolate_ev(x_test, x_train, y_train, extrapolation='extrapolate_constant'),
5257
np.array([4., 1., 2.])
5358
)
5459
np.testing.assert_allclose(
55-
interpolate_ev(x_test, x_train, y_train, x_threshold=1.),
60+
interpolate_ev(x_test, x_train, y_train, x_threshold=1.,
61+
extrapolation='extrapolate_constant'),
5662
np.array([1., 1., 2.])
5763
)
5864
np.testing.assert_allclose(
59-
interpolate_ev(x_test, x_train, y_train, y_threshold=2.),
65+
interpolate_ev(x_test, x_train, y_train, y_threshold=2.,
66+
extrapolation='extrapolate_constant'),
6067
np.array([4., 4., 4.])
6168
)
6269

@@ -66,25 +73,27 @@ def test_interpolate_ev_scale_parameters(self):
6673
y_train = np.array([1., 3.])
6774
x_test = np.array([1e0, 1e2])
6875
np.testing.assert_allclose(
69-
interpolate_ev(x_test, x_train, y_train, logx=True, extrapolation=True),
76+
interpolate_ev(x_test, x_train, y_train, logx=True, extrapolation='extrapolate'),
7077
np.array([0., 2.])
7178
)
7279
np.testing.assert_allclose(
73-
interpolate_ev(x_test, x_train, y_train, logx=True),
80+
interpolate_ev(x_test, x_train, y_train, logx=True,
81+
extrapolation='extrapolate_constant'),
7482
np.array([1., 2.])
7583
)
7684
x_train = np.array([1., 3.])
7785
y_train = np.array([1e1, 1e3])
7886
x_test = np.array([0., 2.])
7987
np.testing.assert_allclose(
80-
interpolate_ev(x_test, x_train, y_train, logy=True, extrapolation=True),
88+
interpolate_ev(x_test, x_train, y_train, logy=True, extrapolation='extrapolate'),
8189
np.array([1e0, 1e2])
8290
)
8391
x_train = np.array([1e1, 1e3])
8492
y_train = np.array([1e1, 1e5])
8593
x_test = np.array([1e0, 1e2])
8694
np.testing.assert_allclose(
87-
interpolate_ev(x_test, x_train, y_train, logx=True, logy=True, extrapolation=True),
95+
interpolate_ev(x_test, x_train, y_train, logx=True, logy=True,
96+
extrapolation='extrapolate'),
8897
np.array([1e-1, 1e3])
8998
)
9099

@@ -95,7 +104,7 @@ def test_interpolate_ev_degenerate_input(self):
95104
y_train = np.zeros(3)
96105
np.testing.assert_allclose(
97106
interpolate_ev(x_test, x_train, y_train),
98-
np.array([0., 0., 0.])
107+
np.array([np.nan, 0., 0.])
99108
)
100109

101110
def test_interpolate_ev_small_input(self):
@@ -104,13 +113,18 @@ def test_interpolate_ev_small_input(self):
104113
y_train = np.array([2.])
105114
x_test = np.array([0., 1., 2.])
106115
np.testing.assert_allclose(
107-
interpolate_ev(x_test, x_train, y_train),
116+
interpolate_ev(x_test, x_train, y_train, extrapolation='extrapolate'),
108117
np.array([2., 2., np.nan])
109118
)
110119
np.testing.assert_allclose(
111-
interpolate_ev(x_test, x_train, y_train, y_asymptotic=0),
120+
interpolate_ev(x_test, x_train, y_train, extrapolation='extrapolate', y_asymptotic=0),
112121
np.array([2., 2., 0.])
113122
)
123+
np.testing.assert_allclose(
124+
interpolate_ev(x_test, x_train, y_train),
125+
np.full(3, np.nan)
126+
)
127+
114128
x_train = np.array([])
115129
y_train = np.array([])
116130
x_test = np.array([0., 1., 2.])
@@ -119,7 +133,8 @@ def test_interpolate_ev_small_input(self):
119133
np.full(3, np.nan)
120134
)
121135
np.testing.assert_allclose(
122-
interpolate_ev(x_test, x_train, y_train, y_asymptotic=0),
136+
interpolate_ev(x_test, x_train, y_train,
137+
extrapolation='extrapolate_constant', y_asymptotic=0),
123138
np.zeros(3)
124139
)
125140

0 commit comments

Comments
 (0)