Skip to content

Commit 1a872d7

Browse files
committed
Added automatic rising detection, and a linear extrapolation kernel.
1 parent b201615 commit 1a872d7

File tree

3 files changed

+103
-22
lines changed

3 files changed

+103
-22
lines changed

stratify/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,5 +2,6 @@
22

33
from ._vinterp import (interpolate,
44
INTERPOLATE_LINEAR, INTERPOLATE_NEAREST,
5-
EXTRAPOLATE_NAN, EXTRAPOLATE_NEAREST)
5+
EXTRAPOLATE_NAN, EXTRAPOLATE_NEAREST,
6+
EXTRAPOLATE_LINEAR)
67

stratify/_vinterp.pyx

Lines changed: 52 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ cdef extern from "math.h" nogil:
1717

1818
__all__ = ['interpolate',
1919
'INTERPOLATE_LINEAR', 'INTERPOLATE_NEAREST',
20-
'EXTRAPOLATE_NAN', 'EXTRAPOLATE_NEAREST']
20+
'EXTRAPOLATE_NAN', 'EXTRAPOLATE_NEAREST', 'EXTRAPOLATE_LINEAR']
2121

2222

2323
# interp_kernel defines the inner part of an interpolation operation.
@@ -270,6 +270,28 @@ cdef long nearest_edge_extrap(int direction, double[:] z_src,
270270
fz_target[i] = fz_src[i, index]
271271

272272

273+
@cython.boundscheck(False)
274+
@cython.wraparound(False)
275+
cdef long linear_extrap(int direction, double[:] z_src,
276+
double[:, :] fz_src, double level,
277+
double[:] fz_target) nogil except -1:
278+
"""Linear extrapolation using either the first or last 2 values."""
279+
cdef unsigned int m = fz_src.shape[0]
280+
cdef unsigned int p0, p1, i
281+
cdef double frac
282+
283+
if direction < 0:
284+
p0, p1 = 0, 1
285+
else:
286+
p0, p1 = fz_src.shape[1] - 2, fz_src.shape[1] - 1
287+
288+
frac = ((level - z_src[p0]) /
289+
(z_src[p1] - z_src[p0]))
290+
291+
for i in range(m):
292+
fz_target[i] = fz_src[i, p0] + frac * (fz_src[i, p1] - fz_src[i, p0])
293+
294+
273295
@cython.boundscheck(False)
274296
@cython.wraparound(False)
275297
cdef long nan_data_extrap(int direction, double[:] z_src,
@@ -335,6 +357,11 @@ cdef class _NearestExtrapKernel(ExtrapKernel):
335357
self.kernel = nearest_edge_extrap
336358

337359

360+
cdef class _LinearExtrapKernel(ExtrapKernel):
361+
def __init__(self):
362+
self.kernel = linear_extrap
363+
364+
338365
cdef class _TestableDirectionExtrapKernel(ExtrapKernel):
339366
def __init__(self):
340367
self.kernel = _testable_direction_extrap
@@ -345,9 +372,10 @@ INTERPOLATE_LINEAR = _LinearInterpKernel()
345372
INTERPOLATE_NEAREST = _NearestInterpKernel()
346373
EXTRAPOLATE_NAN = _NanExtrapKernel()
347374
EXTRAPOLATE_NEAREST = _NearestExtrapKernel()
375+
EXTRAPOLATE_LINEAR = _LinearExtrapKernel()
348376

349377

350-
def interpolate(z_target, z_src, fz_src, axis=-1, rising=True,
378+
def interpolate(z_target, z_src, fz_src, axis=-1, rising=None,
351379
interpolation=INTERPOLATE_LINEAR,
352380
extrapolation=EXTRAPOLATE_NAN):
353381
"""
@@ -382,12 +410,16 @@ def interpolate(z_target, z_src, fz_src, axis=-1, rising=True,
382410
the same as the shape of ``z_src``.
383411
axis: int (default -1)
384412
The axis to perform the interpolation over.
385-
rising: bool (default True)
386-
Whether the values of the target z coordinate are generally rising or
387-
generally falling. For example, values of pressure levels will be
388-
generally falling as the z coordinate increases.
413+
rising: bool (default None)
414+
Whether the values of the source's interpolation coordinate values
415+
are generally rising or generally falling. For example, values of
416+
pressure levels will be generally falling as the z coordinate
417+
increases.
389418
This will determine whether extrapolation needs to occur for
390419
``z_target`` below the first and above the last ``z_src``.
420+
If rising is None, the first two interpolation coordinate values
421+
will be used to determine the general direction. In most cases,
422+
this is a good option.
391423
interpolation: :class:`.InterpKernel` instance
392424
The core interpolation operation to use. :attr:`.INTERPOLATE_LINEAR`
393425
and :attr:`_INTERPOLATE_NEAREST` are provided for convenient
@@ -407,7 +439,7 @@ cdef class _Interpolator(object):
407439
"""
408440
Where the magic happens for gridwise_interp. The work of this __init__ is
409441
mostly for putting the input nd arrays into a 3 and 4 dimensional form for
410-
convenient (read efficient) Cython form. Inline comments should help with
442+
convenient (read: efficient) Cython form. Inline comments should help with
411443
understanding.
412444
413445
"""
@@ -420,7 +452,7 @@ cdef class _Interpolator(object):
420452
cpdef public _result_working_shape, result_shape, _first_value
421453

422454
def __init__(self, z_target, z_src, fz_src, axis=-1,
423-
bint rising=True,
455+
rising=None,
424456
InterpKernel interpolation=INTERPOLATE_LINEAR,
425457
ExtrapKernel extrapolation=EXTRAPOLATE_NAN):
426458
# Cast data to numpy arrays if not already.
@@ -496,7 +528,17 @@ cdef class _Interpolator(object):
496528
#: The shape of the interpolated data.
497529
self.result_shape = tuple(result_shape)
498530

499-
self.rising = rising
531+
if rising is None:
532+
if z_src.shape[zp_axis] < 2:
533+
raise ValueError('The rising keyword must be defined when '
534+
'the size of the source array is <2 in '
535+
'the interpolation axis.')
536+
z_src_indexer = [0] * z_src.ndim
537+
z_src_indexer[zp_axis] = slice(0, 2)
538+
first_two = z_src[z_src_indexer]
539+
rising = first_two[0] <= first_two[1]
540+
541+
self.rising = bool(rising)
500542

501543
self.interpolation = interpolation.kernel
502544
self.extrapolation = extrapolation.kernel
@@ -562,7 +604,7 @@ cdef class _Interpolator(object):
562604
gridwise_interpolation(z_target[i, :, j], z_src[i, :, j], fz_src[:, i, :, j],
563605
self.rising,
564606
self.interpolation,
565-
self.extrapolation,
607+
self.extrapolation,
566608
fz_target_view[:, i, :, j])
567609

568610
return fz_target.reshape(self.result_shape).astype(self._target_dtype)

stratify/tests/test_vinterp.py

Lines changed: 49 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010

1111

1212
class TestColumnInterpolation(unittest.TestCase):
13-
def interpolate(self, x_target, x_src):
13+
def interpolate(self, x_target, x_src, rising=None):
1414
x_target = np.array(x_target)
1515
x_src = np.array(x_src)
1616
fx_src = np.empty(x_src.shape)
@@ -19,13 +19,15 @@ def interpolate(self, x_target, x_src):
1919
extrap_direct = vinterp._TestableDirectionExtrapKernel()
2020

2121
r1 = stratify.interpolate(x_target, x_src, fx_src,
22-
interpolation=index_interp,
23-
extrapolation=extrap_direct)
22+
rising=rising,
23+
interpolation=index_interp,
24+
extrapolation=extrap_direct)
2425

25-
r2 = stratify.interpolate(-1 * x_target, -1 * x_src, fx_src,
26-
rising=False, interpolation=index_interp,
27-
extrapolation=extrap_direct)
28-
assert_array_equal(r1, r2)
26+
if rising is not None:
27+
r2 = stratify.interpolate(-1 * x_target, -1 * x_src, fx_src,
28+
rising=not rising, interpolation=index_interp,
29+
extrapolation=extrap_direct)
30+
assert_array_equal(r1, r2)
2931

3032
return r1
3133

@@ -46,7 +48,7 @@ def test_lower_extrap_only(self):
4648
assert_array_equal(r, [-np.inf, -np.inf, -np.inf])
4749

4850
def test_upper_extrap_only(self):
49-
r = self.interpolate([1, 2, 3], [-4, -5])
51+
r = self.interpolate([1, 2, 3], [-4, -5], rising=True)
5052
assert_array_equal(r, [np.inf, np.inf, np.inf])
5153

5254
def test_extrap_on_both_sides_only(self):
@@ -65,7 +67,7 @@ def test_nan_in_target(self):
6567
def test_nan_in_src(self):
6668
msg = 'The source coordinate .* NaN'
6769
with self.assertRaisesRegexp(ValueError, msg):
68-
self.interpolate([1], [0, np.nan])
70+
self.interpolate([1], [0, np.nan], rising=True)
6971

7072
def test_all_nan_in_src(self):
7173
r = self.interpolate([1, 2, 3, 4], [np.nan, np.nan, np.nan])
@@ -86,13 +88,13 @@ def test_wrong_rising_target(self):
8688
assert_array_equal(r, [1, np.inf])
8789

8890
def test_wrong_rising_source(self):
89-
r = self.interpolate([1, 2], [2, 1])
91+
r = self.interpolate([1, 2], [2, 1], rising=True)
9092
assert_array_equal(r, [-np.inf, 0])
9193

9294
def test_wrong_rising_source_and_target(self):
9395
# If we overshoot the first level, there is no hope,
9496
# so we end up extrapolating.
95-
r = self.interpolate([3, 2, 1, 0], [2, 1])
97+
r = self.interpolate([3, 2, 1, 0], [2, 1], rising=True)
9698
assert_array_equal(r, [np.inf, np.inf, np.inf, np.inf])
9799

98100
def test_non_monotonic_coordinate_interp(self):
@@ -103,6 +105,20 @@ def test_non_monotonic_coordinate_extrap(self):
103105
result = self.interpolate([0, 15, 16, 17, 5, 15., 25], [10., 40, 0, 20])
104106
assert_array_equal(result, [-np.inf, 1, 1, 1, 2, 3, np.inf])
105107

108+
def test_length_one_interp(self):
109+
r = self.interpolate([1], [2], rising=True)
110+
assert_array_equal(r, [-np.inf])
111+
112+
def test_auto_rising_not_enough_values(self):
113+
with self.assertRaises(ValueError):
114+
r = self.interpolate([1], [2])
115+
116+
def test_auto_rising_equal_values(self):
117+
# The code checks whether the first value is <= or equal to
118+
# the second. If it didn't, we'd end up with +inf, not -inf.
119+
r = self.interpolate([1], [2, 2])
120+
assert_array_equal(r, [-np.inf])
121+
106122

107123
class Test_INTERPOLATE_LINEAR(unittest.TestCase):
108124
def interpolate(self, x_target):
@@ -197,6 +213,28 @@ def test_above(self):
197213
assert_array_equal(self.interpolate([5]), [40])
198214

199215

216+
class Test_EXTRAPOLATE_LINEAR(unittest.TestCase):
217+
def interpolate(self, x_target):
218+
interpolation = vinterp._TestableIndexInterpKernel()
219+
extrapolation = stratify.EXTRAPOLATE_LINEAR
220+
221+
x_src = np.arange(5)
222+
# To spice things up a bit, let's make x_src non-equal distance.
223+
x_src[4] = 9
224+
fx_src = 10 * x_src
225+
226+
# Use -2 to test negative number support.
227+
return stratify.interpolate(np.array(x_target) - 2, x_src - 2, fx_src,
228+
interpolation=interpolation,
229+
extrapolation=extrapolation)
230+
231+
def test_below(self):
232+
assert_array_equal(self.interpolate([-1]), [-10.])
233+
234+
def test_above(self):
235+
assert_array_almost_equal(self.interpolate([15.123]), [151.23])
236+
237+
200238
class Test__Interpolator(unittest.TestCase):
201239
def test_axis_m1(self):
202240
data = np.empty([5, 4, 23, 7, 3])

0 commit comments

Comments
 (0)