@@ -45,8 +45,8 @@ cdef inline int relative_sign(double z, double z_base) nogil:
4545@ cython.wraparound (False )
4646cdef long gridwise_interpolation(double [:] z_target, double [:] z_src,
4747 double [:, :] fz_src, bint increasing,
48- InterpKernel interpolation,
49- ExtrapKernel extrapolation,
48+ Interpolator interpolation,
49+ Extrapolator extrapolation,
5050 double [:, :] fz_target) nogil except - 1 :
5151 """
5252 Computes the interpolation of multiple levels of a single column.
@@ -58,9 +58,9 @@ cdef long gridwise_interpolation(double[:] z_target, double[:] z_src,
5858 fz_src - the data to use for the actual interpolation
5959 increasing - true when increasing Z index generally implies increasing Z values
6060 interpolation - the inner interpolation functionality. See the definition of
61- InterpKernel .
61+ Interpolator .
6262 extrapolation - the inner extrapolation functionality. See the definition of
63- ExtrapKernel .
63+ Extrapolator .
6464 fz_target - the pre-allocated array to be used for the outputting the result
6565 of interpolation.
6666
@@ -166,7 +166,7 @@ cdef long gridwise_interpolation(double[:] z_target, double[:] z_src,
166166 z_before = z_current
167167
168168
169- cdef class InterpKernel (object ):
169+ cdef class Interpolator (object ):
170170 cdef long kernel(self , unsigned int index,
171171 double [:] z_src, double [:, :] fz_src,
172172 double level, double [:] fz_level
@@ -189,7 +189,7 @@ cdef class InterpKernel(object):
189189
190190 """
191191 with gil:
192- raise RuntimeError (' InterpKernel subclasses should implement '
192+ raise RuntimeError (' Interpolator subclasses should implement '
193193 ' the kernel function.' )
194194
195195 cdef bint prepare_column(self , double [:] z_target, double [:] z_src,
@@ -198,7 +198,7 @@ cdef class InterpKernel(object):
198198 pass
199199
200200
201- cdef class _LinearInterpKernel(InterpKernel ):
201+ cdef class LinearInterpolator(Interpolator ):
202202 @ cython.boundscheck (False )
203203 @ cython.wraparound (False )
204204 cdef long kernel(self , unsigned int index, double [:] z_src, double [:, :] fz_src, double level,
@@ -230,7 +230,7 @@ cdef class _LinearInterpKernel(InterpKernel):
230230 frac * (fz_src[i, index] - fz_src[i, index - 1 ])
231231
232232
233- cdef class _NearestInterpKernel(InterpKernel ):
233+ cdef class NearestNInterpolator(Interpolator ):
234234 @ cython.boundscheck (False )
235235 @ cython.wraparound (False )
236236 cdef long kernel(self , unsigned int index, double [:] z_src, double [:, :] fz_src, double level,
@@ -248,20 +248,46 @@ cdef class _NearestInterpKernel(InterpKernel):
248248 fz_target[i] = fz_src[i, nearest_index]
249249
250250
251- cdef class _TestableIndexInterpKernel(InterpKernel):
251+ cdef class PyFuncInterpolator(Interpolator):
252+ cdef bint use_column_prep
253+
254+ def __init__ (self , use_column_prep = True ):
255+ self .use_column_prep = use_column_prep
256+
257+ def column_prep (self , z_target , z_src , fz_src , increasing ):
258+ """
259+ Called each time this interpolator sees a new data array.
260+ This method may be used for validation of a column, or for column
261+ based pre-interpolation calculations (e.g. spline gradients).
262+
263+ Note: This method is not called if :attr:`.call_column_prep` is False.
264+
265+ """
266+ pass
267+
268+ cdef bint prepare_column(self , double [:] z_target, double [:] z_src,
269+ double [:, :] fz_src, bint increasing) nogil except - 1 :
270+ if self .use_column_prep:
271+ with gil:
272+ self .column_prep(z_target, z_src, fz_src, increasing)
273+
274+ def interp_kernel (self , index , z_src , fz_src , level , output_array ):
275+ # Fill the output array with the fz_src data at the given index.
276+ # This is nealy equivalent to nearest neighbour, but doesn't take
277+ # into account which neighbour is nearest.
278+ output_array[:] = fz_src[:, index]
279+
252280 @ cython.boundscheck (False )
253281 @ cython.wraparound (False )
254- cdef long kernel(self , unsigned int index, double [:] z_src, double [:, :] fz_src,
255- double level, double [:] fz_target) nogil except - 1 :
256- # A simple, tesable interpolation, which simply returns the index of interpolation.
257- cdef unsigned int m = fz_src.shape[0 ]
258- cdef unsigned int i
259-
260- for i in range (m):
261- fz_target[i] = index
282+ cdef long kernel(self ,
283+ unsigned int index, double [:] z_src,
284+ double [:, :] fz_src, double level,
285+ double [:] fz_target) nogil except - 1 :
286+ with gil:
287+ self .interp_kernel(index, z_src, fz_src, level, fz_target)
262288
263289
264- cdef class ExtrapKernel (object ):
290+ cdef class Extrapolator (object ):
265291 cdef long kernel(self , int direction,
266292 double [:] z_src, double [:, :] fz_src,
267293 double current_level, double [:] fz_target
@@ -281,15 +307,15 @@ cdef class ExtrapKernel(object):
281307 extrapolated values into.
282308 """
283309 with gil:
284- raise RuntimeError (' InterpKernel subclasses should implement '
310+ raise RuntimeError (' Interpolator subclasses should implement '
285311 ' the kernel function.' )
286312
287313 cdef bint prepare_column(self , double [:] z_target, double [:] z_src,
288314 double [:, :] fz_src, bint increasing) nogil except - 1 :
289315 pass
290316
291317
292- cdef class _NanExtrapKernel(ExtrapKernel ):
318+ cdef class NaNExtrapolator(Extrapolator ):
293319 @ cython.boundscheck (False )
294320 @ cython.wraparound (False )
295321 cdef long kernel(self , int direction, double [:] z_src,
@@ -303,13 +329,13 @@ cdef class _NanExtrapKernel(ExtrapKernel):
303329 fz_target[i] = NAN
304330
305331
306- cdef class _NearestExtrapKernel(ExtrapKernel ):
332+ cdef class NearestNExtrapolator(Extrapolator ):
307333 @ cython.boundscheck (False )
308334 @ cython.wraparound (False )
309335 cdef long kernel(self ,
310- int direction, double [:] z_src,
311- double [:, :] fz_src, double level,
312- double [:] fz_target) nogil except - 1 :
336+ int direction, double [:] z_src,
337+ double [:, :] fz_src, double level,
338+ double [:] fz_target) nogil except - 1 :
313339 """ Nearest-neighbour/edge extrapolation."""
314340 cdef unsigned int m = fz_src.shape[0 ]
315341 cdef unsigned int index, i
@@ -323,7 +349,7 @@ cdef class _NearestExtrapKernel(ExtrapKernel):
323349 fz_target[i] = fz_src[i, index]
324350
325351
326- cdef class _LinearExtrapKernel(ExtrapKernel ):
352+ cdef class LinearExtrapolator(Extrapolator ):
327353 cdef bint prepare_column(self , double [:] z_target, double [:] z_src,
328354 double [:, :] fz_src, bint increasing) nogil except - 1 :
329355 cdef unsigned int n_src_pts = z_src.shape[0 ]
@@ -342,21 +368,27 @@ cdef class _LinearExtrapKernel(ExtrapKernel):
342368 cdef unsigned int m = fz_src.shape[0 ]
343369 cdef unsigned int n_src_pts = fz_src.shape[1 ]
344370 cdef unsigned int p0, p1, i
345- cdef double frac
371+ cdef double frac, z_step
346372
347373 if direction < 0 :
348374 p0, p1 = 0 , 1
349375 else :
350376 p0, p1 = n_src_pts - 2 , n_src_pts - 1
351377
352- frac = ((level - z_src[p0]) /
353- (z_src[p1] - z_src[p0]))
378+ # Compute the normalised distance of the target point between p0 and p1
379+ z_step = z_src[p1] - z_src[p0]
380+ if z_step == 0 :
381+ # If there is nothing between the last two points then we
382+ # extrapolate using a 0 gradient.
383+ frac = 0
384+ else :
385+ frac = ((level - z_src[p0]) / z_step)
354386
355387 for i in range (m):
356388 fz_target[i] = fz_src[i, p0] + frac * (fz_src[i, p1] - fz_src[i, p0])
357389
358390
359- cdef class _PythonExtrapKernel(ExtrapKernel ):
391+ cdef class PyFuncExtrapolator(Extrapolator ):
360392 cdef bint use_column_prep
361393
362394 def __init__ (self , use_column_prep = True ):
@@ -392,38 +424,25 @@ cdef class _PythonExtrapKernel(ExtrapKernel):
392424 with gil:
393425 self .extrap_kernel(direction, z_src, fz_src, level, fz_target)
394426
395- cdef class _TestableDirectionExtrapKernel(ExtrapKernel):
396- @ cython.boundscheck (False )
397- @ cython.wraparound (False )
398- cdef long kernel(self ,
399- int direction, double [:] z_src,
400- double [:, :] fz_src, double level,
401- double [:] fz_target) nogil except - 1 :
402- # A simple testable extrapolation which simply returns
403- # -inf if direction == -1 and inf if direction == 1.
404- cdef unsigned int m = fz_src.shape[0 ]
405- cdef double value
406- cdef unsigned int i
407427
408- if direction < 0 :
409- value = - INFINITY
410- else :
411- value = INFINITY
412- for i in range (m):
413- fz_target[i] = value
428+ interp_schemes = { ' nearest ' : NearestNInterpolator,
429+ ' linear ' : LinearInterpolator}
430+
431+ extrap_schemes = { ' nearest ' : NearestNExtrapolator,
432+ ' linear ' : LinearExtrapolator,
433+ ' nan ' : NaNExtrapolator}
414434
415435
416436# Construct interp/extrap constants exposed to the user.
417- INTERPOLATE_LINEAR = _LinearInterpKernel ()
418- INTERPOLATE_NEAREST = _NearestInterpKernel ()
419- EXTRAPOLATE_NAN = _NanExtrapKernel ()
420- EXTRAPOLATE_NEAREST = _NearestExtrapKernel ()
421- EXTRAPOLATE_LINEAR = _LinearExtrapKernel ()
437+ INTERPOLATE_LINEAR = interp_schemes[ ' linear ' ] ()
438+ INTERPOLATE_NEAREST = interp_schemes[ ' nearest ' ] ()
439+ EXTRAPOLATE_NAN = extrap_schemes[ ' nan ' ] ()
440+ EXTRAPOLATE_NEAREST = extrap_schemes[ ' nearest ' ] ()
441+ EXTRAPOLATE_LINEAR = extrap_schemes[ ' linear ' ] ()
422442
423443
424444def interpolate (z_target , z_src , fz_src , axis = - 1 , rising = None ,
425- interpolation = INTERPOLATE_LINEAR,
426- extrapolation = EXTRAPOLATE_NAN):
445+ interpolation = ' linear' , extrapolation = ' nan' ):
427446 """
428447 Interface for optimised 1d interpolation across multiple dimensions.
429448
@@ -466,31 +485,40 @@ def interpolate(z_target, z_src, fz_src, axis=-1, rising=None,
466485 If rising is None, the first two interpolation coordinate values
467486 will be used to determine the general direction. In most cases,
468487 this is a good option.
469- interpolation: :class:`.InterpKernel ` instance
488+ interpolation: :class:`.Interpolator ` instance or valid scheme name
470489 The core interpolation operation to use. :attr:`.INTERPOLATE_LINEAR`
471490 and :attr:`_INTERPOLATE_NEAREST` are provided for convenient
472491 iterpolation modes. Linear interpolation is the default.
473- extrapolation: :class:`.ExtrapKernel ` instance
492+ extrapolation: :class:`.Extrapolator ` instance or valid scheme name
474493 The core extrapolation operation to use. :attr:`.EXTRAPOLATE_NAN` and
475494 :attr:`.EXTRAPOLATE_NEAREST` are provided for convenient extrapolation
476495 modes. NaN extrapolation is the default.
477496
478497 """
479- interp = _Interpolator(z_target, z_src, fz_src, rising = rising, axis = axis,
480- interpolation = interpolation, extrapolation = extrapolation)
481- return interp.interpolate()
498+ if interpolation in interp_schemes:
499+ interpolation = interp_schemes[interpolation]()
500+ if extrapolation in extrap_schemes:
501+ extrapolation = extrap_schemes[extrapolation]()
502+
503+ interp = _Interpolation(z_target, z_src, fz_src, rising = rising, axis = axis,
504+ interpolation = interpolation,
505+ extrapolation = extrapolation)
506+ if interp.z_target.ndim == 1 :
507+ return interp.interpolate()
508+ else :
509+ return interp.interpolate_z_target_nd()
482510
483511
484- cdef class _Interpolator (object ):
512+ cdef class _Interpolation (object ):
485513 """
486514 Where the magic happens for gridwise_interp. The work of this __init__ is
487515 mostly for putting the input nd arrays into a 3 and 4 dimensional form for
488516 convenient (read: efficient) Cython form. Inline comments should help with
489517 understanding.
490518
491519 """
492- cdef InterpKernel interpolation
493- cdef ExtrapKernel extrapolation
520+ cdef Interpolator interpolation
521+ cdef Extrapolator extrapolation
494522
495523 cdef public np.dtype _target_dtype
496524 cdef int rising
@@ -499,8 +527,8 @@ cdef class _Interpolator(object):
499527
500528 def __init__ (self , z_target , z_src , fz_src , axis = - 1 ,
501529 rising = None ,
502- InterpKernel interpolation = INTERPOLATE_LINEAR,
503- ExtrapKernel extrapolation = EXTRAPOLATE_NAN):
530+ Interpolator interpolation = INTERPOLATE_LINEAR,
531+ Extrapolator extrapolation = EXTRAPOLATE_NAN):
504532 # Cast data to numpy arrays if not already.
505533 z_target = np.array(z_target, dtype = np.float64)
506534 z_src = np.array(z_src, dtype = np.float64)
0 commit comments