Skip to content

Commit 260068d

Browse files
committed
Abitility to construct python function based interpolations.
1 parent 502cee1 commit 260068d

File tree

2 files changed

+83
-14
lines changed

2 files changed

+83
-14
lines changed

stratify/_vinterp.pyx

Lines changed: 62 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -76,8 +76,8 @@ cdef inline int relative_sign(double z, double z_base) nogil:
7676
@cython.wraparound(False)
7777
cdef long gridwise_interpolation(double[:] z_target, double[:] z_src,
7878
double[:, :] fz_src, bint increasing,
79-
interp_kernel interpolation_kernel,
80-
extrap_kernel extrapolation_kernel,
79+
InterpKernel interpolation_kernel,
80+
ExtrapKernel extrapolation_kernel,
8181
double [:, :] fz_target) nogil except -1:
8282
"""
8383
Computes the interpolation of multiple levels of a single column.
@@ -131,6 +131,9 @@ cdef long gridwise_interpolation(double[:] z_target, double[:] z_src,
131131
fz_target[i, i_target] = NAN
132132
return 0
133133

134+
interpolation_kernel.prepare_column(z_target, z_src, fz_src, increasing)
135+
extrapolation_kernel.prepare_column(z_target, z_src, fz_src, increasing)
136+
134137
if increasing:
135138
z_before = -INFINITY
136139
else:
@@ -181,10 +184,10 @@ cdef long gridwise_interpolation(double[:] z_target, double[:] z_src,
181184
break
182185

183186
if extrapolating == 0 or sign_after == 0:
184-
interpolation_kernel(i_src, z_src, fz_src, z_current,
187+
interpolation_kernel.kernel(i_src, z_src, fz_src, z_current,
185188
fz_target[:, i_target])
186189
else:
187-
extrapolation_kernel(extrapolating, z_src, fz_src, z_current,
190+
extrapolation_kernel.kernel(extrapolating, z_src, fz_src, z_current,
188191
fz_target[:, i_target])
189192

190193
# Move the lower edge of the window forwards to the level we've just computed,
@@ -281,11 +284,6 @@ cdef long linear_extrap(int direction, double[:] z_src,
281284
cdef unsigned int p0, p1, i
282285
cdef double frac
283286

284-
if n_src_pts < 2:
285-
with gil:
286-
raise ValueError('Linear extrapolation requires at least '
287-
'2 source points. Got {}.'.format(n_src_pts))
288-
289287
if direction < 0:
290288
p0, p1 = 0, 1
291289
else:
@@ -333,6 +331,10 @@ cdef long _testable_direction_extrap(int direction, double[:] z_src,
333331
cdef class InterpKernel(object):
334332
cdef interp_kernel kernel
335333

334+
cdef bint prepare_column(self, double[:] z_target, double[:] z_src,
335+
double[:, :] fz_src, bint increasing) nogil except -1:
336+
pass
337+
336338

337339
cdef class _LinearInterpKernel(InterpKernel):
338340
def __init__(self):
@@ -352,6 +354,10 @@ cdef class _TestableIndexInterpKernel(InterpKernel):
352354
cdef class ExtrapKernel(object):
353355
cdef extrap_kernel kernel
354356

357+
cdef bint prepare_column(self, double[:] z_target, double[:] z_src,
358+
double[:, :] fz_src, bint increasing) nogil except -1:
359+
pass
360+
355361

356362
cdef class _NanExtrapKernel(ExtrapKernel):
357363
def __init__(self):
@@ -367,6 +373,40 @@ cdef class _LinearExtrapKernel(ExtrapKernel):
367373
def __init__(self):
368374
self.kernel = linear_extrap
369375

376+
cdef bint prepare_column(self, double[:] z_target, double[:] z_src,
377+
double[:, :] fz_src, bint increasing) nogil except -1:
378+
cdef unsigned int n_src_pts = z_src.shape[0]
379+
380+
if n_src_pts < 2:
381+
with gil:
382+
raise ValueError('Linear extrapolation requires at least '
383+
'2 source points. Got {}.'.format(n_src_pts))
384+
385+
386+
cdef class _PythonExtrapKernel(ExtrapKernel):
387+
cdef bint use_column_prep
388+
389+
def __init__(self, use_column_prep=True):
390+
self.kernel = linear_extrap
391+
self.use_column_prep = use_column_prep
392+
393+
def column_prep(self, z_target, z_src, fz_src, increasing):
394+
"""
395+
Called each time this extrapolator sees a new data array.
396+
This method may be used for validation of a column, or for column
397+
based pre-interpolation calculations (e.g. spline gradients).
398+
399+
Note: This method is not called if :attr:`.call_column_prep` is False.
400+
401+
"""
402+
pass
403+
404+
cdef bint prepare_column(self, double[:] z_target, double[:] z_src,
405+
double[:, :] fz_src, bint increasing) nogil except -1:
406+
if self.use_column_prep:
407+
with gil:
408+
self.column_prep(z_target, z_src, fz_src, increasing)
409+
370410

371411
cdef class _TestableDirectionExtrapKernel(ExtrapKernel):
372412
def __init__(self):
@@ -449,8 +489,8 @@ cdef class _Interpolator(object):
449489
understanding.
450490
451491
"""
452-
cdef interp_kernel interpolation
453-
cdef extrap_kernel extrapolation
492+
cdef InterpKernel interpolation
493+
cdef ExtrapKernel extrapolation
454494

455495
cdef public np.dtype _target_dtype
456496
cdef int rising
@@ -546,8 +586,17 @@ cdef class _Interpolator(object):
546586

547587
self.rising = bool(rising)
548588

549-
self.interpolation = interpolation.kernel
550-
self.extrapolation = extrapolation.kernel
589+
# Sometimes we want to add additional constraints on our interpolation
590+
# and extrapolation - for example, linear extrapolation requires there
591+
# to be two coordinates to interpolate from.
592+
if hasattr(self.interpolation, 'validate_data'):
593+
self.interpolation.validate_data(self)
594+
595+
if hasattr(self.extrapolation, 'validate_data'):
596+
self.extrapolation.validate_data(self)
597+
598+
self.interpolation = interpolation
599+
self.extrapolation = extrapolation
551600

552601
def interpolate(self):
553602
# Construct the output array for the interpolation to fill in.

stratify/tests/test_vinterp.py

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -253,13 +253,33 @@ def test_npts(self):
253253

254254
msg = (r'Linear extrapolation requires at least 2 '
255255
r'source points. Got 1.')
256-
257256
with self.assertRaisesRegexp(ValueError, msg):
258257
stratify.interpolate([1, 3.], [2], [20],
259258
interpolation=interpolation,
260259
extrapolation=extrapolation, rising=True)
261260

262261

262+
class Test_custom_extrap_kernel(unittest.TestCase):
263+
class my_kernel(vinterp._PythonExtrapKernel):
264+
def __init__(self, *args, **kwargs):
265+
self.wibble = 'blah'
266+
super(Test_custom_extrap_kernel.my_kernel, self).__init__(*args, **kwargs)
267+
268+
def prepare_fn(self, *args):
269+
print('called me!')
270+
raise ValueError()
271+
272+
def test(self):
273+
interpolation = vinterp._TestableIndexInterpKernel()
274+
extrapolation = Test_custom_extrap_kernel.my_kernel()
275+
276+
stratify.interpolate([1, 3.], [1, 2], [10, 20],
277+
interpolation=interpolation,
278+
extrapolation=extrapolation, rising=True)
279+
print(extrapolation.wibble)
280+
281+
282+
263283
class Test__Interpolator(unittest.TestCase):
264284
def test_axis_m1(self):
265285
data = np.empty([5, 4, 23, 7, 3])

0 commit comments

Comments
 (0)