Skip to content

Commit 06fa995

Browse files
changed scale parameter from str to bool
1 parent 82cb0f6 commit 06fa995

File tree

2 files changed

+24
-24
lines changed

2 files changed

+24
-24
lines changed

climada/util/interpolation.py

Lines changed: 19 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -33,8 +33,8 @@ def interpolate_ev(
3333
x_test,
3434
x_train,
3535
y_train,
36-
x_scale = None,
37-
y_scale = None,
36+
logx = False,
37+
logy = False,
3838
x_threshold = None,
3939
y_threshold = None,
4040
y_asymptotic = np.nan,
@@ -52,10 +52,10 @@ def interpolate_ev(
5252
1-D array of x-values of training data
5353
y_train : array_like
5454
1-D array of y-values of training data
55-
x_scale : str, optional
56-
If set to 'log', x_values are convert to log scale. Defaults to None.
57-
y_scale : str, optional
58-
If set to 'log', x_values are convert to log scale. Defaults to None.
55+
logx : bool, optional
56+
If set to True, x_values are convert to log scale. Defaults to False.
57+
logy : bool, optional
58+
If set to True, x_values are convert to log scale. Defaults to False.
5959
x_threshold : float, optional
6060
Lower threshold to filter x_train. Defaults to None.
6161
y_threshold : float, optional
@@ -79,24 +79,24 @@ def interpolate_ev(
7979

8080
# preprocess interpolation data
8181
x_test, x_train, y_train = _preprocess_interpolation_data(
82-
x_test, x_train, y_train, x_scale, y_scale, x_threshold, y_threshold
82+
x_test, x_train, y_train, logx, logy, x_threshold, y_threshold
8383
)
8484

8585
# handle case of small training data sizes
8686
if x_train.size < 2:
8787
LOGGER.warning('Data is being extrapolated.')
88-
return _interpolate_small_input(x_test, x_train, y_train, y_scale, y_asymptotic)
88+
return _interpolate_small_input(x_test, x_train, y_train, logy, y_asymptotic)
8989

9090
# calculate fill values
9191
if isinstance(fill_value, tuple):
9292
if fill_value[0] == 'maximum':
9393
fill_value = (
9494
np.max(y_train),
95-
np.log10(fill_value[1]) if y_scale == 'log' else fill_value[1]
95+
np.log10(fill_value[1]) if logy else fill_value[1]
9696
)
97-
elif y_scale == 'log':
97+
elif logy:
9898
fill_value = tuple(np.log10(fill_value))
99-
elif isinstance(fill_value, (float, int)) and y_scale == 'log':
99+
elif isinstance(fill_value, (float, int)) and logy:
100100
fill_value = np.log10(fill_value)
101101

102102
# warn if data is being extrapolated
@@ -111,7 +111,7 @@ def interpolate_ev(
111111
y_test = interpolation(x_test)
112112

113113
# adapt output scale
114-
if y_scale == 'log':
114+
if logy:
115115
y_test = np.power(10., y_test)
116116
return y_test
117117

@@ -170,8 +170,8 @@ def _preprocess_interpolation_data(
170170
x_test,
171171
x_train,
172172
y_train,
173-
x_scale,
174-
y_scale,
173+
logx,
174+
logy,
175175
x_threshold,
176176
y_threshold
177177
):
@@ -201,23 +201,23 @@ def _preprocess_interpolation_data(
201201
y_train = y_train[y_th]
202202

203203
# convert to log scale
204-
if x_scale == 'log':
204+
if logx:
205205
x_train, x_test = np.log10(x_train), np.log10(x_test)
206-
if y_scale == 'log':
206+
if logy:
207207
y_train = np.log10(y_train)
208208

209209
return (x_test, x_train, y_train)
210210

211-
def _interpolate_small_input(x_test, x_train, y_train, y_scale, y_asymptotic):
211+
def _interpolate_small_input(x_test, x_train, y_train, logy, y_asymptotic):
212212
"""
213213
helper function to handle if interpolation data is small (empty or one point)
214214
"""
215215
# return y_asymptotic if x_train and y_train empty
216216
if x_train.size == 0:
217217
return np.full_like(x_test, y_asymptotic)
218218

219-
# reconvert logarithmic y_scale to normal y_train
220-
if y_scale == 'log':
219+
# reconvert logarithmic y_train to original y_train
220+
if logy:
221221
y_train = np.power(10., y_train)
222222

223223
# if only one (x_train, y_train), return stepfunction with

climada/util/test/test_interpolation.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -66,29 +66,29 @@ def test_interpolate_ev_scale_parameters(self):
6666
y_train = np.array([1., 3.])
6767
x_test = np.array([1e0, 1e2])
6868
np.testing.assert_allclose(
69-
interpolate_ev(x_test, x_train, y_train, x_scale='log', fill_value='extrapolate'),
69+
interpolate_ev(x_test, x_train, y_train, logx=True, fill_value='extrapolate'),
7070
np.array([0., 2.])
7171
)
7272
np.testing.assert_allclose(
73-
interpolate_ev(x_test, x_train, y_train, x_scale='log'),
73+
interpolate_ev(x_test, x_train, y_train, logx=True),
7474
np.array([np.nan, 2.])
7575
)
7676
x_train = np.array([1., 3.])
7777
y_train = np.array([1e1, 1e3])
7878
x_test = np.array([0., 2.])
7979
np.testing.assert_allclose(
80-
interpolate_ev(x_test, x_train, y_train, y_scale='log', fill_value='extrapolate'),
80+
interpolate_ev(x_test, x_train, y_train, logy=True, fill_value='extrapolate'),
8181
np.array([1e0, 1e2])
8282
)
8383
np.testing.assert_allclose(
84-
interpolate_ev(x_test, x_train, y_train, y_scale='log'),
84+
interpolate_ev(x_test, x_train, y_train, logy=True),
8585
np.array([np.nan, 1e2])
8686
)
8787
x_train = np.array([1e1, 1e3])
8888
y_train = np.array([1e1, 1e5])
8989
x_test = np.array([1e0, 1e2])
9090
np.testing.assert_allclose(
91-
interpolate_ev(x_test, x_train, y_train, x_scale='log', y_scale='log', fill_value='extrapolate'),
91+
interpolate_ev(x_test, x_train, y_train, logx=True, logy=True, fill_value='extrapolate'),
9292
np.array([1e-1, 1e3])
9393
)
9494

0 commit comments

Comments
 (0)