Skip to content

Commit adf7594

Browse files
authored
Merge pull request #1 from pelson/linear_extrap_and_auto_rise
Added automatic rising detection, and a linear extrapolation kernel.
2 parents b201615 + 502cee1 commit adf7594

File tree

3 files changed

+140
-28
lines changed

3 files changed

+140
-28
lines changed

stratify/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,5 +2,6 @@
22

33
from ._vinterp import (interpolate,
44
INTERPOLATE_LINEAR, INTERPOLATE_NEAREST,
5-
EXTRAPOLATE_NAN, EXTRAPOLATE_NEAREST)
5+
EXTRAPOLATE_NAN, EXTRAPOLATE_NEAREST,
6+
EXTRAPOLATE_LINEAR)
67

stratify/_vinterp.pyx

Lines changed: 58 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ cdef extern from "math.h" nogil:
1717

1818
__all__ = ['interpolate',
1919
'INTERPOLATE_LINEAR', 'INTERPOLATE_NEAREST',
20-
'EXTRAPOLATE_NAN', 'EXTRAPOLATE_NEAREST']
20+
'EXTRAPOLATE_NAN', 'EXTRAPOLATE_NEAREST', 'EXTRAPOLATE_LINEAR']
2121

2222

2323
# interp_kernel defines the inner part of an interpolation operation.
@@ -270,6 +270,34 @@ cdef long nearest_edge_extrap(int direction, double[:] z_src,
270270
fz_target[i] = fz_src[i, index]
271271

272272

273+
@cython.boundscheck(False)
274+
@cython.wraparound(False)
275+
cdef long linear_extrap(int direction, double[:] z_src,
276+
double[:, :] fz_src, double level,
277+
double[:] fz_target) nogil except -1:
278+
"""Linear extrapolation using either the first or last 2 values."""
279+
cdef unsigned int m = fz_src.shape[0]
280+
cdef unsigned int n_src_pts = fz_src.shape[1]
281+
cdef unsigned int p0, p1, i
282+
cdef double frac
283+
284+
if n_src_pts < 2:
285+
with gil:
286+
raise ValueError('Linear extrapolation requires at least '
287+
'2 source points. Got {}.'.format(n_src_pts))
288+
289+
if direction < 0:
290+
p0, p1 = 0, 1
291+
else:
292+
p0, p1 = n_src_pts - 2, n_src_pts - 1
293+
294+
frac = ((level - z_src[p0]) /
295+
(z_src[p1] - z_src[p0]))
296+
297+
for i in range(m):
298+
fz_target[i] = fz_src[i, p0] + frac * (fz_src[i, p1] - fz_src[i, p0])
299+
300+
273301
@cython.boundscheck(False)
274302
@cython.wraparound(False)
275303
cdef long nan_data_extrap(int direction, double[:] z_src,
@@ -335,6 +363,11 @@ cdef class _NearestExtrapKernel(ExtrapKernel):
335363
self.kernel = nearest_edge_extrap
336364

337365

366+
cdef class _LinearExtrapKernel(ExtrapKernel):
367+
def __init__(self):
368+
self.kernel = linear_extrap
369+
370+
338371
cdef class _TestableDirectionExtrapKernel(ExtrapKernel):
339372
def __init__(self):
340373
self.kernel = _testable_direction_extrap
@@ -345,9 +378,10 @@ INTERPOLATE_LINEAR = _LinearInterpKernel()
345378
INTERPOLATE_NEAREST = _NearestInterpKernel()
346379
EXTRAPOLATE_NAN = _NanExtrapKernel()
347380
EXTRAPOLATE_NEAREST = _NearestExtrapKernel()
381+
EXTRAPOLATE_LINEAR = _LinearExtrapKernel()
348382

349383

350-
def interpolate(z_target, z_src, fz_src, axis=-1, rising=True,
384+
def interpolate(z_target, z_src, fz_src, axis=-1, rising=None,
351385
interpolation=INTERPOLATE_LINEAR,
352386
extrapolation=EXTRAPOLATE_NAN):
353387
"""
@@ -382,12 +416,16 @@ def interpolate(z_target, z_src, fz_src, axis=-1, rising=True,
382416
the same as the shape of ``z_src``.
383417
axis: int (default -1)
384418
The axis to perform the interpolation over.
385-
rising: bool (default True)
386-
Whether the values of the target z coordinate are generally rising or
387-
generally falling. For example, values of pressure levels will be
388-
generally falling as the z coordinate increases.
419+
rising: bool (default None)
420+
Whether the values of the source's interpolation coordinate values
421+
are generally rising or generally falling. For example, values of
422+
pressure levels will be generally falling as the z coordinate
423+
increases.
389424
This will determine whether extrapolation needs to occur for
390425
``z_target`` below the first and above the last ``z_src``.
426+
If rising is None, the first two interpolation coordinate values
427+
will be used to determine the general direction. In most cases,
428+
this is a good option.
391429
interpolation: :class:`.InterpKernel` instance
392430
The core interpolation operation to use. :attr:`.INTERPOLATE_LINEAR`
393431
and :attr:`_INTERPOLATE_NEAREST` are provided for convenient
@@ -407,7 +445,7 @@ cdef class _Interpolator(object):
407445
"""
408446
Where the magic happens for gridwise_interp. The work of this __init__ is
409447
mostly for putting the input nd arrays into a 3 and 4 dimensional form for
410-
convenient (read efficient) Cython form. Inline comments should help with
448+
convenient (read: efficient) Cython form. Inline comments should help with
411449
understanding.
412450
413451
"""
@@ -420,7 +458,7 @@ cdef class _Interpolator(object):
420458
cpdef public _result_working_shape, result_shape, _first_value
421459

422460
def __init__(self, z_target, z_src, fz_src, axis=-1,
423-
bint rising=True,
461+
rising=None,
424462
InterpKernel interpolation=INTERPOLATE_LINEAR,
425463
ExtrapKernel extrapolation=EXTRAPOLATE_NAN):
426464
# Cast data to numpy arrays if not already.
@@ -496,7 +534,17 @@ cdef class _Interpolator(object):
496534
#: The shape of the interpolated data.
497535
self.result_shape = tuple(result_shape)
498536

499-
self.rising = rising
537+
if rising is None:
538+
if z_src.shape[zp_axis] < 2:
539+
raise ValueError('The rising keyword must be defined when '
540+
'the size of the source array is <2 in '
541+
'the interpolation axis.')
542+
z_src_indexer = [0] * z_src.ndim
543+
z_src_indexer[zp_axis] = slice(0, 2)
544+
first_two = z_src[z_src_indexer]
545+
rising = first_two[0] <= first_two[1]
546+
547+
self.rising = bool(rising)
500548

501549
self.interpolation = interpolation.kernel
502550
self.extrapolation = extrapolation.kernel
@@ -562,7 +610,7 @@ cdef class _Interpolator(object):
562610
gridwise_interpolation(z_target[i, :, j], z_src[i, :, j], fz_src[:, i, :, j],
563611
self.rising,
564612
self.interpolation,
565-
self.extrapolation,
613+
self.extrapolation,
566614
fz_target_view[:, i, :, j])
567615

568616
return fz_target.reshape(self.result_shape).astype(self._target_dtype)

stratify/tests/test_vinterp.py

Lines changed: 80 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010

1111

1212
class TestColumnInterpolation(unittest.TestCase):
13-
def interpolate(self, x_target, x_src):
13+
def interpolate(self, x_target, x_src, rising=None):
1414
x_target = np.array(x_target)
1515
x_src = np.array(x_src)
1616
fx_src = np.empty(x_src.shape)
@@ -19,13 +19,15 @@ def interpolate(self, x_target, x_src):
1919
extrap_direct = vinterp._TestableDirectionExtrapKernel()
2020

2121
r1 = stratify.interpolate(x_target, x_src, fx_src,
22-
interpolation=index_interp,
23-
extrapolation=extrap_direct)
22+
rising=rising,
23+
interpolation=index_interp,
24+
extrapolation=extrap_direct)
2425

25-
r2 = stratify.interpolate(-1 * x_target, -1 * x_src, fx_src,
26-
rising=False, interpolation=index_interp,
27-
extrapolation=extrap_direct)
28-
assert_array_equal(r1, r2)
26+
if rising is not None:
27+
r2 = stratify.interpolate(-1 * x_target, -1 * x_src, fx_src,
28+
rising=not rising, interpolation=index_interp,
29+
extrapolation=extrap_direct)
30+
assert_array_equal(r1, r2)
2931

3032
return r1
3133

@@ -46,7 +48,7 @@ def test_lower_extrap_only(self):
4648
assert_array_equal(r, [-np.inf, -np.inf, -np.inf])
4749

4850
def test_upper_extrap_only(self):
49-
r = self.interpolate([1, 2, 3], [-4, -5])
51+
r = self.interpolate([1, 2, 3], [-4, -5], rising=True)
5052
assert_array_equal(r, [np.inf, np.inf, np.inf])
5153

5254
def test_extrap_on_both_sides_only(self):
@@ -65,7 +67,7 @@ def test_nan_in_target(self):
6567
def test_nan_in_src(self):
6668
msg = 'The source coordinate .* NaN'
6769
with self.assertRaisesRegexp(ValueError, msg):
68-
self.interpolate([1], [0, np.nan])
70+
self.interpolate([1], [0, np.nan], rising=True)
6971

7072
def test_all_nan_in_src(self):
7173
r = self.interpolate([1, 2, 3, 4], [np.nan, np.nan, np.nan])
@@ -86,13 +88,13 @@ def test_wrong_rising_target(self):
8688
assert_array_equal(r, [1, np.inf])
8789

8890
def test_wrong_rising_source(self):
89-
r = self.interpolate([1, 2], [2, 1])
91+
r = self.interpolate([1, 2], [2, 1], rising=True)
9092
assert_array_equal(r, [-np.inf, 0])
9193

9294
def test_wrong_rising_source_and_target(self):
9395
# If we overshoot the first level, there is no hope,
9496
# so we end up extrapolating.
95-
r = self.interpolate([3, 2, 1, 0], [2, 1])
97+
r = self.interpolate([3, 2, 1, 0], [2, 1], rising=True)
9698
assert_array_equal(r, [np.inf, np.inf, np.inf, np.inf])
9799

98100
def test_non_monotonic_coordinate_interp(self):
@@ -103,6 +105,20 @@ def test_non_monotonic_coordinate_extrap(self):
103105
result = self.interpolate([0, 15, 16, 17, 5, 15., 25], [10., 40, 0, 20])
104106
assert_array_equal(result, [-np.inf, 1, 1, 1, 2, 3, np.inf])
105107

108+
def test_length_one_interp(self):
109+
r = self.interpolate([1], [2], rising=True)
110+
assert_array_equal(r, [-np.inf])
111+
112+
def test_auto_rising_not_enough_values(self):
113+
with self.assertRaises(ValueError):
114+
r = self.interpolate([1], [2])
115+
116+
def test_auto_rising_equal_values(self):
117+
# The code checks whether the first value is <= or equal to
118+
# the second. If it didn't, we'd end up with +inf, not -inf.
119+
r = self.interpolate([1], [2, 2])
120+
assert_array_equal(r, [-np.inf])
121+
106122

107123
class Test_INTERPOLATE_LINEAR(unittest.TestCase):
108124
def interpolate(self, x_target):
@@ -129,6 +145,19 @@ def test_high_precision(self):
129145
assert_array_almost_equal(self.interpolate([1.123456789]),
130146
[11.23456789], decimal=6)
131147

148+
def test_single_point(self):
149+
# Test that a single input point that falls exactly on the target
150+
# level triggers a shortcut that avoids the expectation of >=2 source
151+
# points.
152+
interpolation = stratify.INTERPOLATE_LINEAR
153+
extrapolation = vinterp._TestableDirectionExtrapKernel()
154+
155+
r = stratify.interpolate([2], [2], [20],
156+
interpolation=interpolation,
157+
extrapolation=extrapolation,
158+
rising=True)
159+
self.assertEqual(r, 20)
160+
132161

133162
class Test_INTERPOLATE_NEAREST(unittest.TestCase):
134163
def interpolate(self, x_target):
@@ -140,8 +169,8 @@ def interpolate(self, x_target):
140169

141170
# Use -2 to test negative number support.
142171
return stratify.interpolate(np.array(x_target) - 2, x_src - 2, fx_src,
143-
interpolation=interpolation,
144-
extrapolation=extrapolation)
172+
interpolation=interpolation,
173+
extrapolation=extrapolation)
145174

146175
def test_on_the_mark(self):
147176
assert_array_equal(self.interpolate([0, 1, 2, 3, 4]),
@@ -167,8 +196,8 @@ def interpolate(self, x_target):
167196

168197
# Use -2 to test negative number support.
169198
return stratify.interpolate(np.array(x_target) - 2, x_src - 2, fx_src,
170-
interpolation=interpolation,
171-
extrapolation=extrapolation)
199+
interpolation=interpolation,
200+
extrapolation=extrapolation)
172201

173202
def test_below(self):
174203
assert_array_equal(self.interpolate([-1]), [np.nan])
@@ -187,8 +216,8 @@ def interpolate(self, x_target):
187216

188217
# Use -2 to test negative number support.
189218
return stratify.interpolate(np.array(x_target) - 2, x_src - 2, fx_src,
190-
interpolation=interpolation,
191-
extrapolation=extrapolation)
219+
interpolation=interpolation,
220+
extrapolation=extrapolation)
192221

193222
def test_below(self):
194223
assert_array_equal(self.interpolate([-1]), [0.])
@@ -197,6 +226,40 @@ def test_above(self):
197226
assert_array_equal(self.interpolate([5]), [40])
198227

199228

229+
class Test_EXTRAPOLATE_LINEAR(unittest.TestCase):
230+
def interpolate(self, x_target):
231+
interpolation = vinterp._TestableIndexInterpKernel()
232+
extrapolation = stratify.EXTRAPOLATE_LINEAR
233+
234+
x_src = np.arange(5)
235+
# To spice things up a bit, let's make x_src non-equal distance.
236+
x_src[4] = 9
237+
fx_src = 10 * x_src
238+
239+
# Use -2 to test negative number support.
240+
return stratify.interpolate(np.array(x_target) - 2, x_src - 2, fx_src,
241+
interpolation=interpolation,
242+
extrapolation=extrapolation)
243+
244+
def test_below(self):
245+
assert_array_equal(self.interpolate([-1]), [-10.])
246+
247+
def test_above(self):
248+
assert_array_almost_equal(self.interpolate([15.123]), [151.23])
249+
250+
def test_npts(self):
251+
interpolation = vinterp._TestableIndexInterpKernel()
252+
extrapolation = stratify.EXTRAPOLATE_LINEAR
253+
254+
msg = (r'Linear extrapolation requires at least 2 '
255+
r'source points. Got 1.')
256+
257+
with self.assertRaisesRegexp(ValueError, msg):
258+
stratify.interpolate([1, 3.], [2], [20],
259+
interpolation=interpolation,
260+
extrapolation=extrapolation, rising=True)
261+
262+
200263
class Test__Interpolator(unittest.TestCase):
201264
def test_axis_m1(self):
202265
data = np.empty([5, 4, 23, 7, 3])

0 commit comments

Comments
 (0)