Skip to content

Commit 257ae95

Browse files
authored
Merge pull request #4 from pelson/linear_extrap_and_auto_rise
Tidy some of the interface and expose a PyFuncInteprolator/PyFuncExtrapolator
2 parents e7f33f6 + f2cb8de commit 257ae95

File tree

3 files changed

+148
-88
lines changed

3 files changed

+148
-88
lines changed

stratify/__init__.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
from __future__ import absolute_import
22

3-
from ._vinterp import (interpolate,
3+
from ._vinterp import (interpolate, interp_schemes, extrap_schemes,
44
INTERPOLATE_LINEAR, INTERPOLATE_NEAREST,
55
EXTRAPOLATE_NAN, EXTRAPOLATE_NEAREST,
6-
EXTRAPOLATE_LINEAR)
6+
EXTRAPOLATE_LINEAR, PyFuncExtrapolator,
7+
PyFuncInterpolator)
78

stratify/_vinterp.pyx

Lines changed: 92 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -45,8 +45,8 @@ cdef inline int relative_sign(double z, double z_base) nogil:
4545
@cython.wraparound(False)
4646
cdef long gridwise_interpolation(double[:] z_target, double[:] z_src,
4747
double[:, :] fz_src, bint increasing,
48-
InterpKernel interpolation,
49-
ExtrapKernel extrapolation,
48+
Interpolator interpolation,
49+
Extrapolator extrapolation,
5050
double [:, :] fz_target) nogil except -1:
5151
"""
5252
Computes the interpolation of multiple levels of a single column.
@@ -58,9 +58,9 @@ cdef long gridwise_interpolation(double[:] z_target, double[:] z_src,
5858
fz_src - the data to use for the actual interpolation
5959
increasing - true when increasing Z index generally implies increasing Z values
6060
interpolation - the inner interpolation functionality. See the definition of
61-
InterpKernel.
61+
Interpolator.
6262
extrapolation - the inner extrapolation functionality. See the definition of
63-
ExtrapKernel.
63+
Extrapolator.
6464
fz_target - the pre-allocated array to be used for the outputting the result
6565
of interpolation.
6666
@@ -166,7 +166,7 @@ cdef long gridwise_interpolation(double[:] z_target, double[:] z_src,
166166
z_before = z_current
167167

168168

169-
cdef class InterpKernel(object):
169+
cdef class Interpolator(object):
170170
cdef long kernel(self, unsigned int index,
171171
double[:] z_src, double[:, :] fz_src,
172172
double level, double[:] fz_level
@@ -189,7 +189,7 @@ cdef class InterpKernel(object):
189189
190190
"""
191191
with gil:
192-
raise RuntimeError('InterpKernel subclasses should implement '
192+
raise RuntimeError('Interpolator subclasses should implement '
193193
'the kernel function.')
194194

195195
cdef bint prepare_column(self, double[:] z_target, double[:] z_src,
@@ -198,7 +198,7 @@ cdef class InterpKernel(object):
198198
pass
199199

200200

201-
cdef class _LinearInterpKernel(InterpKernel):
201+
cdef class LinearInterpolator(Interpolator):
202202
@cython.boundscheck(False)
203203
@cython.wraparound(False)
204204
cdef long kernel(self, unsigned int index, double[:] z_src, double[:, :] fz_src, double level,
@@ -230,7 +230,7 @@ cdef class _LinearInterpKernel(InterpKernel):
230230
frac * (fz_src[i, index] - fz_src[i, index - 1])
231231

232232

233-
cdef class _NearestInterpKernel(InterpKernel):
233+
cdef class NearestNInterpolator(Interpolator):
234234
@cython.boundscheck(False)
235235
@cython.wraparound(False)
236236
cdef long kernel(self, unsigned int index, double[:] z_src, double[:, :] fz_src, double level,
@@ -248,20 +248,46 @@ cdef class _NearestInterpKernel(InterpKernel):
248248
fz_target[i] = fz_src[i, nearest_index]
249249

250250

251-
cdef class _TestableIndexInterpKernel(InterpKernel):
251+
cdef class PyFuncInterpolator(Interpolator):
252+
cdef bint use_column_prep
253+
254+
def __init__(self, use_column_prep=True):
255+
self.use_column_prep = use_column_prep
256+
257+
def column_prep(self, z_target, z_src, fz_src, increasing):
258+
"""
259+
Called each time this interpolator sees a new data array.
260+
This method may be used for validation of a column, or for column
261+
based pre-interpolation calculations (e.g. spline gradients).
262+
263+
Note: This method is not called if :attr:`.call_column_prep` is False.
264+
265+
"""
266+
pass
267+
268+
cdef bint prepare_column(self, double[:] z_target, double[:] z_src,
269+
double[:, :] fz_src, bint increasing) nogil except -1:
270+
if self.use_column_prep:
271+
with gil:
272+
self.column_prep(z_target, z_src, fz_src, increasing)
273+
274+
def interp_kernel(self, index, z_src, fz_src, level, output_array):
275+
# Fill the output array with the fz_src data at the given index.
276+
# This is nealy equivalent to nearest neighbour, but doesn't take
277+
# into account which neighbour is nearest.
278+
output_array[:] = fz_src[:, index]
279+
252280
@cython.boundscheck(False)
253281
@cython.wraparound(False)
254-
cdef long kernel(self, unsigned int index, double[:] z_src, double[:, :] fz_src,
255-
double level, double[:] fz_target) nogil except -1:
256-
# A simple, tesable interpolation, which simply returns the index of interpolation.
257-
cdef unsigned int m = fz_src.shape[0]
258-
cdef unsigned int i
259-
260-
for i in range(m):
261-
fz_target[i] = index
282+
cdef long kernel(self,
283+
unsigned int index, double[:] z_src,
284+
double[:, :] fz_src, double level,
285+
double[:] fz_target) nogil except -1:
286+
with gil:
287+
self.interp_kernel(index, z_src, fz_src, level, fz_target)
262288

263289

264-
cdef class ExtrapKernel(object):
290+
cdef class Extrapolator(object):
265291
cdef long kernel(self, int direction,
266292
double[:] z_src, double[:, :] fz_src,
267293
double current_level, double[:] fz_target
@@ -281,15 +307,15 @@ cdef class ExtrapKernel(object):
281307
extrapolated values into.
282308
"""
283309
with gil:
284-
raise RuntimeError('InterpKernel subclasses should implement '
310+
raise RuntimeError('Interpolator subclasses should implement '
285311
'the kernel function.')
286312

287313
cdef bint prepare_column(self, double[:] z_target, double[:] z_src,
288314
double[:, :] fz_src, bint increasing) nogil except -1:
289315
pass
290316

291317

292-
cdef class _NanExtrapKernel(ExtrapKernel):
318+
cdef class NaNExtrapolator(Extrapolator):
293319
@cython.boundscheck(False)
294320
@cython.wraparound(False)
295321
cdef long kernel(self, int direction, double[:] z_src,
@@ -303,13 +329,13 @@ cdef class _NanExtrapKernel(ExtrapKernel):
303329
fz_target[i] = NAN
304330

305331

306-
cdef class _NearestExtrapKernel(ExtrapKernel):
332+
cdef class NearestNExtrapolator(Extrapolator):
307333
@cython.boundscheck(False)
308334
@cython.wraparound(False)
309335
cdef long kernel(self,
310-
int direction, double[:] z_src,
311-
double[:, :] fz_src, double level,
312-
double[:] fz_target) nogil except -1:
336+
int direction, double[:] z_src,
337+
double[:, :] fz_src, double level,
338+
double[:] fz_target) nogil except -1:
313339
"""Nearest-neighbour/edge extrapolation."""
314340
cdef unsigned int m = fz_src.shape[0]
315341
cdef unsigned int index, i
@@ -323,7 +349,7 @@ cdef class _NearestExtrapKernel(ExtrapKernel):
323349
fz_target[i] = fz_src[i, index]
324350

325351

326-
cdef class _LinearExtrapKernel(ExtrapKernel):
352+
cdef class LinearExtrapolator(Extrapolator):
327353
cdef bint prepare_column(self, double[:] z_target, double[:] z_src,
328354
double[:, :] fz_src, bint increasing) nogil except -1:
329355
cdef unsigned int n_src_pts = z_src.shape[0]
@@ -342,21 +368,27 @@ cdef class _LinearExtrapKernel(ExtrapKernel):
342368
cdef unsigned int m = fz_src.shape[0]
343369
cdef unsigned int n_src_pts = fz_src.shape[1]
344370
cdef unsigned int p0, p1, i
345-
cdef double frac
371+
cdef double frac, z_step
346372

347373
if direction < 0:
348374
p0, p1 = 0, 1
349375
else:
350376
p0, p1 = n_src_pts - 2, n_src_pts - 1
351377

352-
frac = ((level - z_src[p0]) /
353-
(z_src[p1] - z_src[p0]))
378+
# Compute the normalised distance of the target point between p0 and p1
379+
z_step = z_src[p1] - z_src[p0]
380+
if z_step == 0:
381+
# If there is nothing between the last two points then we
382+
# extrapolate using a 0 gradient.
383+
frac = 0
384+
else:
385+
frac = ((level - z_src[p0]) / z_step)
354386

355387
for i in range(m):
356388
fz_target[i] = fz_src[i, p0] + frac * (fz_src[i, p1] - fz_src[i, p0])
357389

358390

359-
cdef class _PythonExtrapKernel(ExtrapKernel):
391+
cdef class PyFuncExtrapolator(Extrapolator):
360392
cdef bint use_column_prep
361393

362394
def __init__(self, use_column_prep=True):
@@ -392,38 +424,25 @@ cdef class _PythonExtrapKernel(ExtrapKernel):
392424
with gil:
393425
self.extrap_kernel(direction, z_src, fz_src, level, fz_target)
394426

395-
cdef class _TestableDirectionExtrapKernel(ExtrapKernel):
396-
@cython.boundscheck(False)
397-
@cython.wraparound(False)
398-
cdef long kernel(self,
399-
int direction, double[:] z_src,
400-
double[:, :] fz_src, double level,
401-
double[:] fz_target) nogil except -1:
402-
# A simple testable extrapolation which simply returns
403-
# -inf if direction == -1 and inf if direction == 1.
404-
cdef unsigned int m = fz_src.shape[0]
405-
cdef double value
406-
cdef unsigned int i
407427

408-
if direction < 0:
409-
value = -INFINITY
410-
else:
411-
value = INFINITY
412-
for i in range(m):
413-
fz_target[i] = value
428+
interp_schemes = {'nearest': NearestNInterpolator,
429+
'linear': LinearInterpolator}
430+
431+
extrap_schemes = {'nearest': NearestNExtrapolator,
432+
'linear': LinearExtrapolator,
433+
'nan': NaNExtrapolator}
414434

415435

416436
# Construct interp/extrap constants exposed to the user.
417-
INTERPOLATE_LINEAR = _LinearInterpKernel()
418-
INTERPOLATE_NEAREST = _NearestInterpKernel()
419-
EXTRAPOLATE_NAN = _NanExtrapKernel()
420-
EXTRAPOLATE_NEAREST = _NearestExtrapKernel()
421-
EXTRAPOLATE_LINEAR = _LinearExtrapKernel()
437+
INTERPOLATE_LINEAR = interp_schemes['linear']()
438+
INTERPOLATE_NEAREST = interp_schemes['nearest']()
439+
EXTRAPOLATE_NAN = extrap_schemes['nan']()
440+
EXTRAPOLATE_NEAREST = extrap_schemes['nearest']()
441+
EXTRAPOLATE_LINEAR = extrap_schemes['linear']()
422442

423443

424444
def interpolate(z_target, z_src, fz_src, axis=-1, rising=None,
425-
interpolation=INTERPOLATE_LINEAR,
426-
extrapolation=EXTRAPOLATE_NAN):
445+
interpolation='linear', extrapolation='nan'):
427446
"""
428447
Interface for optimised 1d interpolation across multiple dimensions.
429448
@@ -466,31 +485,40 @@ def interpolate(z_target, z_src, fz_src, axis=-1, rising=None,
466485
If rising is None, the first two interpolation coordinate values
467486
will be used to determine the general direction. In most cases,
468487
this is a good option.
469-
interpolation: :class:`.InterpKernel` instance
488+
interpolation: :class:`.Interpolator` instance or valid scheme name
470489
The core interpolation operation to use. :attr:`.INTERPOLATE_LINEAR`
471490
and :attr:`_INTERPOLATE_NEAREST` are provided for convenient
472491
iterpolation modes. Linear interpolation is the default.
473-
extrapolation: :class:`.ExtrapKernel` instance
492+
extrapolation: :class:`.Extrapolator` instance or valid scheme name
474493
The core extrapolation operation to use. :attr:`.EXTRAPOLATE_NAN` and
475494
:attr:`.EXTRAPOLATE_NEAREST` are provided for convenient extrapolation
476495
modes. NaN extrapolation is the default.
477496
478497
"""
479-
interp = _Interpolator(z_target, z_src, fz_src, rising=rising, axis=axis,
480-
interpolation=interpolation, extrapolation=extrapolation)
481-
return interp.interpolate()
498+
if interpolation in interp_schemes:
499+
interpolation = interp_schemes[interpolation]()
500+
if extrapolation in extrap_schemes:
501+
extrapolation = extrap_schemes[extrapolation]()
502+
503+
interp = _Interpolation(z_target, z_src, fz_src, rising=rising, axis=axis,
504+
interpolation=interpolation,
505+
extrapolation=extrapolation)
506+
if interp.z_target.ndim == 1:
507+
return interp.interpolate()
508+
else:
509+
return interp.interpolate_z_target_nd()
482510

483511

484-
cdef class _Interpolator(object):
512+
cdef class _Interpolation(object):
485513
"""
486514
Where the magic happens for gridwise_interp. The work of this __init__ is
487515
mostly for putting the input nd arrays into a 3 and 4 dimensional form for
488516
convenient (read: efficient) Cython form. Inline comments should help with
489517
understanding.
490518
491519
"""
492-
cdef InterpKernel interpolation
493-
cdef ExtrapKernel extrapolation
520+
cdef Interpolator interpolation
521+
cdef Extrapolator extrapolation
494522

495523
cdef public np.dtype _target_dtype
496524
cdef int rising
@@ -499,8 +527,8 @@ cdef class _Interpolator(object):
499527

500528
def __init__(self, z_target, z_src, fz_src, axis=-1,
501529
rising=None,
502-
InterpKernel interpolation=INTERPOLATE_LINEAR,
503-
ExtrapKernel extrapolation=EXTRAPOLATE_NAN):
530+
Interpolator interpolation=INTERPOLATE_LINEAR,
531+
Extrapolator extrapolation=EXTRAPOLATE_NAN):
504532
# Cast data to numpy arrays if not already.
505533
z_target = np.array(z_target, dtype=np.float64)
506534
z_src = np.array(z_src, dtype=np.float64)

0 commit comments

Comments
 (0)