@@ -45,8 +45,8 @@ cdef inline int relative_sign(double z, double z_base) nogil:
4545@ cython.wraparound (False )
4646cdef long gridwise_interpolation(double [:] z_target, double [:] z_src,
4747 double [:, :] fz_src, bint increasing,
48- InterpKernel interpolation,
49- ExtrapKernel extrapolation,
48+ Interpolator interpolation,
49+ Extrapolator extrapolation,
5050 double [:, :] fz_target) nogil except - 1 :
5151 """
5252 Computes the interpolation of multiple levels of a single column.
@@ -58,9 +58,9 @@ cdef long gridwise_interpolation(double[:] z_target, double[:] z_src,
5858 fz_src - the data to use for the actual interpolation
5959 increasing - true when increasing Z index generally implies increasing Z values
6060 interpolation - the inner interpolation functionality. See the definition of
61- InterpKernel .
61+ Interpolator .
6262 extrapolation - the inner extrapolation functionality. See the definition of
63- ExtrapKernel .
63+ Extrapolator .
6464 fz_target - the pre-allocated array to be used for the outputting the result
6565 of interpolation.
6666
@@ -166,7 +166,7 @@ cdef long gridwise_interpolation(double[:] z_target, double[:] z_src,
166166 z_before = z_current
167167
168168
169- cdef class InterpKernel (object ):
169+ cdef class Interpolator (object ):
170170 cdef long kernel(self , unsigned int index,
171171 double [:] z_src, double [:, :] fz_src,
172172 double level, double [:] fz_level
@@ -189,7 +189,7 @@ cdef class InterpKernel(object):
189189
190190 """
191191 with gil:
192- raise RuntimeError (' InterpKernel subclasses should implement '
192+ raise RuntimeError (' Interpolator subclasses should implement '
193193 ' the kernel function.' )
194194
195195 cdef bint prepare_column(self , double [:] z_target, double [:] z_src,
@@ -198,7 +198,7 @@ cdef class InterpKernel(object):
198198 pass
199199
200200
201- cdef class _LinearInterpKernel(InterpKernel ):
201+ cdef class LinearInterpolator(Interpolator ):
202202 @ cython.boundscheck (False )
203203 @ cython.wraparound (False )
204204 cdef long kernel(self , unsigned int index, double [:] z_src, double [:, :] fz_src, double level,
@@ -230,7 +230,7 @@ cdef class _LinearInterpKernel(InterpKernel):
230230 frac * (fz_src[i, index] - fz_src[i, index - 1 ])
231231
232232
233- cdef class _NearestInterpKernel(InterpKernel ):
233+ cdef class NearestNInterpolator(Interpolator ):
234234 @ cython.boundscheck (False )
235235 @ cython.wraparound (False )
236236 cdef long kernel(self , unsigned int index, double [:] z_src, double [:, :] fz_src, double level,
@@ -248,20 +248,46 @@ cdef class _NearestInterpKernel(InterpKernel):
248248 fz_target[i] = fz_src[i, nearest_index]
249249
250250
251- cdef class _TestableIndexInterpKernel(InterpKernel):
251+ cdef class PyFuncInterpolator(Interpolator):
252+ cdef bint use_column_prep
253+
254+ def __init__ (self , use_column_prep = True ):
255+ self .use_column_prep = use_column_prep
256+
257+ def column_prep (self , z_target , z_src , fz_src , increasing ):
258+ """
259+ Called each time this interpolator sees a new data array.
260+ This method may be used for validation of a column, or for column
261+ based pre-interpolation calculations (e.g. spline gradients).
262+
263+ Note: This method is not called if :attr:`.call_column_prep` is False.
264+
265+ """
266+ pass
267+
268+ cdef bint prepare_column(self , double [:] z_target, double [:] z_src,
269+ double [:, :] fz_src, bint increasing) nogil except - 1 :
270+ if self .use_column_prep:
271+ with gil:
272+ self .column_prep(z_target, z_src, fz_src, increasing)
273+
274+ def interp_kernel (self , index , z_src , fz_src , level , output_array ):
275+ # Fill the output array with the fz_src data at the given index.
276+ # This is nealy equivalent to nearest neighbour, but doesn't take
277+ # into account which neighbour is nearest.
278+ output_array[:] = fz_src[:, index]
279+
252280 @ cython.boundscheck (False )
253281 @ cython.wraparound (False )
254- cdef long kernel(self , unsigned int index, double [:] z_src, double [:, :] fz_src,
255- double level, double [:] fz_target) nogil except - 1 :
256- # A simple, tesable interpolation, which simply returns the index of interpolation.
257- cdef unsigned int m = fz_src.shape[0 ]
258- cdef unsigned int i
259-
260- for i in range (m):
261- fz_target[i] = index
282+ cdef long kernel(self ,
283+ unsigned int index, double [:] z_src,
284+ double [:, :] fz_src, double level,
285+ double [:] fz_target) nogil except - 1 :
286+ with gil:
287+ self .interp_kernel(index, z_src, fz_src, level, fz_target)
262288
263289
264- cdef class ExtrapKernel (object ):
290+ cdef class Extrapolator (object ):
265291 cdef long kernel(self , int direction,
266292 double [:] z_src, double [:, :] fz_src,
267293 double current_level, double [:] fz_target
@@ -281,15 +307,15 @@ cdef class ExtrapKernel(object):
281307 extrapolated values into.
282308 """
283309 with gil:
284- raise RuntimeError (' InterpKernel subclasses should implement '
310+ raise RuntimeError (' Interpolator subclasses should implement '
285311 ' the kernel function.' )
286312
287313 cdef bint prepare_column(self , double [:] z_target, double [:] z_src,
288314 double [:, :] fz_src, bint increasing) nogil except - 1 :
289315 pass
290316
291317
292- cdef class _NanExtrapKernel(ExtrapKernel ):
318+ cdef class NaNExtrapolator(Extrapolator ):
293319 @ cython.boundscheck (False )
294320 @ cython.wraparound (False )
295321 cdef long kernel(self , int direction, double [:] z_src,
@@ -303,13 +329,13 @@ cdef class _NanExtrapKernel(ExtrapKernel):
303329 fz_target[i] = NAN
304330
305331
306- cdef class _NearestExtrapKernel(ExtrapKernel ):
332+ cdef class NearestNExtrapolator(Extrapolator ):
307333 @ cython.boundscheck (False )
308334 @ cython.wraparound (False )
309335 cdef long kernel(self ,
310- int direction, double [:] z_src,
311- double [:, :] fz_src, double level,
312- double [:] fz_target) nogil except - 1 :
336+ int direction, double [:] z_src,
337+ double [:, :] fz_src, double level,
338+ double [:] fz_target) nogil except - 1 :
313339 """ Nearest-neighbour/edge extrapolation."""
314340 cdef unsigned int m = fz_src.shape[0 ]
315341 cdef unsigned int index, i
@@ -323,7 +349,7 @@ cdef class _NearestExtrapKernel(ExtrapKernel):
323349 fz_target[i] = fz_src[i, index]
324350
325351
326- cdef class _LinearExtrapKernel(ExtrapKernel ):
352+ cdef class LinearExtrapolator(Extrapolator ):
327353 cdef bint prepare_column(self , double [:] z_target, double [:] z_src,
328354 double [:, :] fz_src, bint increasing) nogil except - 1 :
329355 cdef unsigned int n_src_pts = z_src.shape[0 ]
@@ -356,7 +382,7 @@ cdef class _LinearExtrapKernel(ExtrapKernel):
356382 fz_target[i] = fz_src[i, p0] + frac * (fz_src[i, p1] - fz_src[i, p0])
357383
358384
359- cdef class _PythonExtrapKernel(ExtrapKernel ):
385+ cdef class PyFuncExtrapolator(Extrapolator ):
360386 cdef bint use_column_prep
361387
362388 def __init__ (self , use_column_prep = True ):
@@ -392,38 +418,25 @@ cdef class _PythonExtrapKernel(ExtrapKernel):
392418 with gil:
393419 self .extrap_kernel(direction, z_src, fz_src, level, fz_target)
394420
395- cdef class _TestableDirectionExtrapKernel(ExtrapKernel):
396- @ cython.boundscheck (False )
397- @ cython.wraparound (False )
398- cdef long kernel(self ,
399- int direction, double [:] z_src,
400- double [:, :] fz_src, double level,
401- double [:] fz_target) nogil except - 1 :
402- # A simple testable extrapolation which simply returns
403- # -inf if direction == -1 and inf if direction == 1.
404- cdef unsigned int m = fz_src.shape[0 ]
405- cdef double value
406- cdef unsigned int i
407421
408- if direction < 0 :
409- value = - INFINITY
410- else :
411- value = INFINITY
412- for i in range (m):
413- fz_target[i] = value
422+ interp_schemes = { ' nearest ' : NearestNInterpolator,
423+ ' linear ' : LinearInterpolator}
424+
425+ extrap_schemes = { ' nearest ' : NearestNExtrapolator,
426+ ' linear ' : LinearExtrapolator,
427+ ' nan ' : NaNExtrapolator}
414428
415429
416430# Construct interp/extrap constants exposed to the user.
417- INTERPOLATE_LINEAR = _LinearInterpKernel ()
418- INTERPOLATE_NEAREST = _NearestInterpKernel ()
419- EXTRAPOLATE_NAN = _NanExtrapKernel ()
420- EXTRAPOLATE_NEAREST = _NearestExtrapKernel ()
421- EXTRAPOLATE_LINEAR = _LinearExtrapKernel ()
431+ INTERPOLATE_LINEAR = interp_schemes[ ' linear ' ] ()
432+ INTERPOLATE_NEAREST = interp_schemes[ ' nearest ' ] ()
433+ EXTRAPOLATE_NAN = extrap_schemes[ ' nan ' ] ()
434+ EXTRAPOLATE_NEAREST = extrap_schemes[ ' nearest ' ] ()
435+ EXTRAPOLATE_LINEAR = extrap_schemes[ ' linear ' ] ()
422436
423437
424438def interpolate (z_target , z_src , fz_src , axis = - 1 , rising = None ,
425- interpolation = INTERPOLATE_LINEAR,
426- extrapolation = EXTRAPOLATE_NAN):
439+ interpolation = ' linear' , extrapolation = ' nan' ):
427440 """
428441 Interface for optimised 1d interpolation across multiple dimensions.
429442
@@ -466,31 +479,40 @@ def interpolate(z_target, z_src, fz_src, axis=-1, rising=None,
466479 If rising is None, the first two interpolation coordinate values
467480 will be used to determine the general direction. In most cases,
468481 this is a good option.
469- interpolation: :class:`.InterpKernel ` instance
482+ interpolation: :class:`.Interpolator ` instance or valid scheme name
470483 The core interpolation operation to use. :attr:`.INTERPOLATE_LINEAR`
471484 and :attr:`_INTERPOLATE_NEAREST` are provided for convenient
472485 iterpolation modes. Linear interpolation is the default.
473- extrapolation: :class:`.ExtrapKernel ` instance
486+ extrapolation: :class:`.Extrapolator ` instance or valid scheme name
474487 The core extrapolation operation to use. :attr:`.EXTRAPOLATE_NAN` and
475488 :attr:`.EXTRAPOLATE_NEAREST` are provided for convenient extrapolation
476489 modes. NaN extrapolation is the default.
477490
478491 """
479- interp = _Interpolator(z_target, z_src, fz_src, rising = rising, axis = axis,
480- interpolation = interpolation, extrapolation = extrapolation)
481- return interp.interpolate()
492+ if interpolation in interp_schemes:
493+ interpolation = interp_schemes[interpolation]()
494+ if extrapolation in extrap_schemes:
495+ extrapolation = extrap_schemes[extrapolation]()
496+
497+ interp = _Interpolation(z_target, z_src, fz_src, rising = rising, axis = axis,
498+ interpolation = interpolation,
499+ extrapolation = extrapolation)
500+ if interp.z_target.ndim == 1 :
501+ return interp.interpolate()
502+ else :
503+ return interp.interpolate_z_target_nd()
482504
483505
484- cdef class _Interpolator (object ):
506+ cdef class _Interpolation (object ):
485507 """
486508 Where the magic happens for gridwise_interp. The work of this __init__ is
487509 mostly for putting the input nd arrays into a 3 and 4 dimensional form for
488510 convenient (read: efficient) Cython form. Inline comments should help with
489511 understanding.
490512
491513 """
492- cdef InterpKernel interpolation
493- cdef ExtrapKernel extrapolation
514+ cdef Interpolator interpolation
515+ cdef Extrapolator extrapolation
494516
495517 cdef public np.dtype _target_dtype
496518 cdef int rising
@@ -499,8 +521,8 @@ cdef class _Interpolator(object):
499521
500522 def __init__ (self , z_target , z_src , fz_src , axis = - 1 ,
501523 rising = None ,
502- InterpKernel interpolation = INTERPOLATE_LINEAR,
503- ExtrapKernel extrapolation = EXTRAPOLATE_NAN):
524+ Interpolator interpolation = INTERPOLATE_LINEAR,
525+ Extrapolator extrapolation = EXTRAPOLATE_NAN):
504526 # Cast data to numpy arrays if not already.
505527 z_target = np.array(z_target, dtype = np.float64)
506528 z_src = np.array(z_src, dtype = np.float64)
0 commit comments