@@ -76,8 +76,8 @@ cdef inline int relative_sign(double z, double z_base) nogil:
7676@ cython.wraparound (False )
7777cdef long gridwise_interpolation(double [:] z_target, double [:] z_src,
7878 double [:, :] fz_src, bint increasing,
79- interp_kernel interpolation_kernel,
80- extrap_kernel extrapolation_kernel,
79+ InterpKernel interpolation_kernel,
80+ ExtrapKernel extrapolation_kernel,
8181 double [:, :] fz_target) nogil except - 1 :
8282 """
8383 Computes the interpolation of multiple levels of a single column.
@@ -131,6 +131,9 @@ cdef long gridwise_interpolation(double[:] z_target, double[:] z_src,
131131 fz_target[i, i_target] = NAN
132132 return 0
133133
134+ interpolation_kernel.prepare_column(z_target, z_src, fz_src, increasing)
135+ extrapolation_kernel.prepare_column(z_target, z_src, fz_src, increasing)
136+
134137 if increasing:
135138 z_before = - INFINITY
136139 else :
@@ -181,10 +184,10 @@ cdef long gridwise_interpolation(double[:] z_target, double[:] z_src,
181184 break
182185
183186 if extrapolating == 0 or sign_after == 0 :
184- interpolation_kernel(i_src, z_src, fz_src, z_current,
187+ interpolation_kernel.kernel (i_src, z_src, fz_src, z_current,
185188 fz_target[:, i_target])
186189 else :
187- extrapolation_kernel(extrapolating, z_src, fz_src, z_current,
190+ extrapolation_kernel.kernel (extrapolating, z_src, fz_src, z_current,
188191 fz_target[:, i_target])
189192
190193 # Move the lower edge of the window forwards to the level we've just computed,
@@ -281,11 +284,6 @@ cdef long linear_extrap(int direction, double[:] z_src,
281284 cdef unsigned int p0, p1, i
282285 cdef double frac
283286
284- if n_src_pts < 2 :
285- with gil:
286- raise ValueError (' Linear extrapolation requires at least '
287- ' 2 source points. Got {}.' .format(n_src_pts))
288-
289287 if direction < 0 :
290288 p0, p1 = 0 , 1
291289 else :
@@ -333,6 +331,10 @@ cdef long _testable_direction_extrap(int direction, double[:] z_src,
333331cdef class InterpKernel(object ):
334332 cdef interp_kernel kernel
335333
334+ cdef bint prepare_column(self , double [:] z_target, double [:] z_src,
335+ double [:, :] fz_src, bint increasing) nogil except - 1 :
336+ pass
337+
336338
337339cdef class _LinearInterpKernel(InterpKernel):
338340 def __init__ (self ):
@@ -352,6 +354,10 @@ cdef class _TestableIndexInterpKernel(InterpKernel):
352354cdef class ExtrapKernel(object ):
353355 cdef extrap_kernel kernel
354356
357+ cdef bint prepare_column(self , double [:] z_target, double [:] z_src,
358+ double [:, :] fz_src, bint increasing) nogil except - 1 :
359+ pass
360+
355361
356362cdef class _NanExtrapKernel(ExtrapKernel):
357363 def __init__ (self ):
@@ -367,6 +373,40 @@ cdef class _LinearExtrapKernel(ExtrapKernel):
367373 def __init__ (self ):
368374 self .kernel = linear_extrap
369375
376+ cdef bint prepare_column(self , double [:] z_target, double [:] z_src,
377+ double [:, :] fz_src, bint increasing) nogil except - 1 :
378+ cdef unsigned int n_src_pts = z_src.shape[0 ]
379+
380+ if n_src_pts < 2 :
381+ with gil:
382+ raise ValueError (' Linear extrapolation requires at least '
383+ ' 2 source points. Got {}.' .format(n_src_pts))
384+
385+
386+ cdef class _PythonExtrapKernel(ExtrapKernel):
387+ cdef bint use_column_prep
388+
389+ def __init__ (self , use_column_prep = True ):
390+ self .kernel = linear_extrap
391+ self .use_column_prep = use_column_prep
392+
393+ def column_prep (self , z_target , z_src , fz_src , increasing ):
394+ """
395+ Called each time this extrapolator sees a new data array.
396+ This method may be used for validation of a column, or for column
397+ based pre-interpolation calculations (e.g. spline gradients).
398+
399+ Note: This method is not called if :attr:`.call_column_prep` is False.
400+
401+ """
402+ pass
403+
404+ cdef bint prepare_column(self , double [:] z_target, double [:] z_src,
405+ double [:, :] fz_src, bint increasing) nogil except - 1 :
406+ if self .use_column_prep:
407+ with gil:
408+ self .column_prep(z_target, z_src, fz_src, increasing)
409+
370410
371411cdef class _TestableDirectionExtrapKernel(ExtrapKernel):
372412 def __init__ (self ):
@@ -449,8 +489,8 @@ cdef class _Interpolator(object):
449489 understanding.
450490
451491 """
452- cdef interp_kernel interpolation
453- cdef extrap_kernel extrapolation
492+ cdef InterpKernel interpolation
493+ cdef ExtrapKernel extrapolation
454494
455495 cdef public np.dtype _target_dtype
456496 cdef int rising
@@ -546,8 +586,17 @@ cdef class _Interpolator(object):
546586
547587 self .rising = bool (rising)
548588
549- self .interpolation = interpolation.kernel
550- self .extrapolation = extrapolation.kernel
589+ # Sometimes we want to add additional constraints on our interpolation
590+ # and extrapolation - for example, linear extrapolation requires there
591+ # to be two coordinates to interpolate from.
592+ if hasattr (self .interpolation, ' validate_data' ):
593+ self .interpolation.validate_data(self )
594+
595+ if hasattr (self .extrapolation, ' validate_data' ):
596+ self .extrapolation.validate_data(self )
597+
598+ self .interpolation = interpolation
599+ self .extrapolation = extrapolation
551600
552601 def interpolate (self ):
553602 # Construct the output array for the interpolation to fill in.
0 commit comments