Skip to content

Commit ce15aae

Browse files
committed
WIP: still failing tests
1 parent 510faa9 commit ce15aae

File tree

2 files changed

+85
-71
lines changed

2 files changed

+85
-71
lines changed

src/stratify/_vinterp.pyx

Lines changed: 64 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,8 @@ cdef inline int relative_sign(double z, double z_base) nogil:
5353
@cython.boundscheck(False)
5454
@cython.wraparound(False)
5555
cdef long gridwise_interpolation(double[:] z_target, double[:] z_src,
56-
double[:, :] fz_src, bint increasing,
56+
double[:, :] fz_src, bint rising,
57+
bint aligned,
5758
Interpolator interpolation,
5859
Extrapolator extrapolation,
5960
double [:, :] fz_target) nogil except -1:
@@ -65,7 +66,8 @@ cdef long gridwise_interpolation(double[:] z_target, double[:] z_src,
6566
z_target - the levels to interpolate the source data ``fz_src`` to.
6667
z_src - the levels that the source data ``fz_src`` is interpolated from.
6768
fz_src - the source data to be interpolated.
68-
increasing - true when increasing Z index generally implies increasing Z values
69+
rising - true when rising Z index generally implies rising Z values
70+
aligned - true when both src and tgt increase/decrease in the same direction
6971
interpolation - the inner interpolation functionality. See the definition of
7072
Interpolator.
7173
extrapolation - the inner extrapolation functionality. See the definition of
@@ -91,7 +93,7 @@ cdef long gridwise_interpolation(double[:] z_target, double[:] z_src,
9193
cdef unsigned int i_src, i_target, n_src, n_target, i, m
9294
cdef bint all_nans = True
9395
cdef double z_before, z_current, z_after, z_last
94-
cdef int sign_after, sign_before, extrapolating
96+
cdef int sign_after, sign_before, extrapolating, z_final
9597

9698
n_src = z_src.shape[0]
9799
n_target = z_target.shape[0]
@@ -110,13 +112,12 @@ cdef long gridwise_interpolation(double[:] z_target, double[:] z_src,
110112
fz_target[i, i_target] = NAN
111113
return 0
112114

113-
interpolation.prepare_column(z_target, z_src, fz_src, increasing)
114-
extrapolation.prepare_column(z_target, z_src, fz_src, increasing)
115+
interpolation.prepare_column(z_target, z_src, fz_src, rising)
116+
extrapolation.prepare_column(z_target, z_src, fz_src, rising)
117+
with gil:
118+
z_src = np.asarray(z_src)
115119

116-
if increasing:
117-
z_before = -INFINITY
118-
else:
119-
z_before = INFINITY
120+
z_before = -INFINITY if rising else INFINITY
120121

121122
z_last = -z_before
122123

@@ -125,7 +126,11 @@ cdef long gridwise_interpolation(double[:] z_target, double[:] z_src,
125126
# first window value (typically -inf, but may be +inf) and the first z_src.
126127
# This search window will be moved along until a crossing is detected, at
127128
# which point we will do an interpolation.
128-
z_after = z_src[0]
129+
with gil:
130+
z_final = z_src.size - 1
131+
132+
133+
z_after = z_src[0] if aligned else z_src[z_final]
129134

130135
# We start in extrapolation mode. This will be turned off as soon as we
131136
# start increasing i_src.
@@ -151,7 +156,12 @@ cdef long gridwise_interpolation(double[:] z_target, double[:] z_src,
151156
i_src += 1
152157
if i_src < n_src:
153158
extrapolating = 0
154-
z_after = z_src[i_src]
159+
with gil:
160+
if aligned:
161+
z_after = z_src[i_src]
162+
else:
163+
dummy = z_src.size - (i_src + 1)
164+
z_after = z_src[dummy]
155165
if isnan(z_after):
156166
with gil:
157167
raise ValueError('The source coordinate may not contain NaN values.')
@@ -201,7 +211,7 @@ cdef class Interpolator(object):
201211
'the kernel function.')
202212

203213
cdef bint prepare_column(self, double[:] z_target, double[:] z_src,
204-
double[:, :] fz_src, bint increasing) nogil except -1:
214+
double[:, :] fz_src, bint rising) nogil except -1:
205215
# Called before all levels are interpolated.
206216
pass
207217

@@ -262,7 +272,7 @@ cdef class PyFuncInterpolator(Interpolator):
262272
def __init__(self, use_column_prep=True):
263273
self.use_column_prep = use_column_prep
264274

265-
def column_prep(self, z_target, z_src, fz_src, increasing):
275+
def column_prep(self, z_target, z_src, fz_src, rising):
266276
"""
267277
Called each time this interpolator sees a new data array.
268278
This method may be used for validation of a column, or for column
@@ -274,10 +284,10 @@ cdef class PyFuncInterpolator(Interpolator):
274284
pass
275285

276286
cdef bint prepare_column(self, double[:] z_target, double[:] z_src,
277-
double[:, :] fz_src, bint increasing) nogil except -1:
287+
double[:, :] fz_src, bint rising) nogil except -1:
278288
if self.use_column_prep:
279289
with gil:
280-
self.column_prep(z_target, z_src, fz_src, increasing)
290+
self.column_prep(z_target, z_src, fz_src, rising)
281291

282292
def interp_kernel(self, index, z_src, fz_src, level, output_array):
283293
# Fill the output array with the fz_src data at the given index.
@@ -319,7 +329,7 @@ cdef class Extrapolator(object):
319329
'the kernel function.')
320330

321331
cdef bint prepare_column(self, double[:] z_target, double[:] z_src,
322-
double[:, :] fz_src, bint increasing) nogil except -1:
332+
double[:, :] fz_src, bint rising) nogil except -1:
323333
pass
324334

325335

@@ -359,7 +369,7 @@ cdef class NearestNExtrapolator(Extrapolator):
359369

360370
cdef class LinearExtrapolator(Extrapolator):
361371
cdef bint prepare_column(self, double[:] z_target, double[:] z_src,
362-
double[:, :] fz_src, bint increasing) nogil except -1:
372+
double[:, :] fz_src, bint rising) nogil except -1:
363373
cdef unsigned int n_src_pts = z_src.shape[0]
364374

365375
if n_src_pts < 2:
@@ -402,7 +412,7 @@ cdef class PyFuncExtrapolator(Extrapolator):
402412
def __init__(self, use_column_prep=True):
403413
self.use_column_prep = use_column_prep
404414

405-
def column_prep(self, z_target, z_src, fz_src, increasing):
415+
def column_prep(self, z_target, z_src, fz_src, rising):
406416
"""
407417
Called each time this extrapolator sees a new data array.
408418
This method may be used for validation of a column, or for column
@@ -414,10 +424,10 @@ cdef class PyFuncExtrapolator(Extrapolator):
414424
pass
415425

416426
cdef bint prepare_column(self, double[:] z_target, double[:] z_src,
417-
double[:, :] fz_src, bint increasing) nogil except -1:
427+
double[:, :] fz_src, bint rising) nogil except -1:
418428
if self.use_column_prep:
419429
with gil:
420-
self.column_prep(z_target, z_src, fz_src, increasing)
430+
self.column_prep(z_target, z_src, fz_src, rising)
421431

422432
def extrap_kernel(self, direction, z_src, fz_src, level, output_array):
423433
# Fill the output array with nans.
@@ -449,7 +459,7 @@ EXTRAPOLATE_NEAREST = extrap_schemes['nearest']()
449459
EXTRAPOLATE_LINEAR = extrap_schemes['linear']()
450460

451461

452-
def interpolate(z_target, z_src, fz_src, axis=-1, rising=None,
462+
def interpolate(z_target, z_src, fz_src, rising=None, axis=-1,
453463
interpolation='linear', extrapolation='nan'):
454464
"""
455465
Interface for optimised 1d interpolation across multiple dimensions.
@@ -486,16 +496,6 @@ def interpolate(z_target, z_src, fz_src, axis=-1, rising=None,
486496
the same as the shape of ``z_src``.
487497
axis: int (default -1)
488498
The ``fz_src`` axis to perform the interpolation over.
489-
rising: bool (default None)
490-
Whether the values of the source's interpolation coordinate values
491-
are generally rising or generally falling. For example, values of
492-
pressure levels will be generally falling as the z coordinate
493-
increases.
494-
This will determine whether extrapolation needs to occur for
495-
``z_target`` below the first and above the last ``z_src``.
496-
If rising is None, the first two interpolation coordinate values
497-
will be used to determine the general direction. In most cases,
498-
this is a good option.
499499
interpolation: :class:`.Interpolator` instance or valid scheme name
500500
The core interpolation operation to use. :attr:`.INTERPOLATE_LINEAR`
501501
and :attr:`_INTERPOLATE_NEAREST` are provided for convenient
@@ -509,7 +509,6 @@ def interpolate(z_target, z_src, fz_src, axis=-1, rising=None,
509509
func = functools.partial(
510510
_interpolate,
511511
axis=axis,
512-
rising=rising,
513512
interpolation=interpolation,
514513
extrapolation=extrapolation
515514
)
@@ -564,14 +563,14 @@ def interpolate(z_target, z_src, fz_src, axis=-1, rising=None,
564563
meta=np.array((), dtype=fz_src.dtype))
565564

566565

567-
def _interpolate(z_target, z_src, fz_src, axis=-1, rising=None,
566+
def _interpolate(z_target, z_src, fz_src, axis=-1,
568567
interpolation='linear', extrapolation='nan'):
569568
if interpolation in interp_schemes:
570569
interpolation = interp_schemes[interpolation]()
571570
if extrapolation in extrap_schemes:
572571
extrapolation = extrap_schemes[extrapolation]()
573572

574-
interp = _Interpolation(z_target, z_src, fz_src, rising=rising, axis=axis,
573+
interp = _Interpolation(z_target, z_src, fz_src, axis=axis,
575574
interpolation=interpolation,
576575
extrapolation=extrapolation)
577576
if interp.z_target.ndim == 1:
@@ -583,16 +582,14 @@ def _interpolate(z_target, z_src, fz_src, axis=-1, rising=None,
583582
cdef class _Interpolation(object):
584583
"""
585584
Where the magic happens for gridwise_interp. The work of this __init__ is
586-
mostly for putting the input nd arrays into a 3 and 4 dimensional form for
587-
convenient (read: efficient) Cython form. Inline comments should help with
588-
understanding.
585+
mostly for putting the input nd.
589586
590587
"""
591588
cdef Interpolator interpolation
592589
cdef Extrapolator extrapolation
593590

594591
cdef public np.dtype _target_dtype
595-
cdef int rising
592+
cdef rising, aligned
596593
cdef public z_target, orig_shape, axis, _zp_reshaped, _fp_reshaped
597594
cdef public _result_working_shape, result_shape
598595

@@ -692,17 +689,27 @@ cdef class _Interpolation(object):
692689
#: The shape of the interpolated data.
693690
self.result_shape = tuple(result_shape)
694691

695-
if rising is None:
696-
if z_src.shape[zp_axis] < 2:
697-
raise ValueError('The rising keyword must be defined when '
698-
'the size of the source array is <2 in '
699-
'the interpolation axis.')
700-
z_src_indexer = [0] * z_src.ndim
701-
z_src_indexer[zp_axis] = slice(0, 2)
702-
first_two = z_src[tuple(z_src_indexer)]
703-
rising = first_two[0] <= first_two[1]
692+
if z_src.shape[zp_axis] < 2:
693+
raise ValueError('The rising keyword must be defined when '
694+
'the size of the source array is <2 in '
695+
'the interpolation axis.')
696+
704697

705-
self.rising = bool(rising)
698+
z_src_indexer = [0] * z_src.ndim
699+
z_src_indexer[zp_axis] = slice(0, 2)
700+
src_first_two = z_src[tuple(z_src_indexer)]
701+
src_rising = src_first_two[0] <= src_first_two[1]
702+
src_rise = bool(src_rising)
703+
704+
z_tgt_indexer = [0] * z_target.ndim
705+
z_tgt_indexer[zp_axis] = slice(0, 2)
706+
tgt_first_two = z_target[tuple(z_tgt_indexer)]
707+
tgt_rising = tgt_first_two[0] <= tgt_first_two[1]
708+
tgt_rise = bool(tgt_rising)
709+
710+
711+
self.rising = bool(tgt_rising)
712+
self.aligned = src_rise == tgt_rise
706713

707714
# Sometimes we want to add additional constraints on our interpolation
708715
# and extrapolation - for example, linear extrapolation requires there
@@ -733,13 +740,17 @@ cdef class _Interpolation(object):
733740
# Construct a memory view of the fz_target array.
734741
cdef double[:, :, :, :] fz_target_view = fz_target
735742

743+
cdef int rising = self.rising
744+
cdef int aligned = self.aligned
745+
736746
# Release the GIL and do the for loop over the left-hand, and
737747
# right-hand dimensions. The loop optimised for row-major data (C).
738748
with nogil:
739749
for j in range(nj):
740750
for i in range(ni):
741751
gridwise_interpolation(z_target, z_src[i, :, j], fz_src[:, i, :, j],
742-
self.rising,
752+
rising,
753+
aligned,
743754
self.interpolation,
744755
self.extrapolation,
745756
fz_target_view[:, i, :, j])
@@ -755,6 +766,8 @@ cdef class _Interpolation(object):
755766
fz_target = np.empty(self._result_working_shape, dtype=np.float64)
756767

757768
cdef unsigned int i, j, ni, nj
769+
cdef int rising = self.rising
770+
cdef int aligned = self.aligned
758771

759772
ni = fz_target.shape[1]
760773
nj = fz_target.shape[3]
@@ -775,7 +788,8 @@ cdef class _Interpolation(object):
775788
for j in range(nj):
776789
for i in range(ni):
777790
gridwise_interpolation(z_target[i, :, j], z_src[i, :, j], fz_src[:, i, :, j],
778-
self.rising,
791+
rising,
792+
aligned,
779793
self.interpolation,
780794
self.extrapolation,
781795
fz_target_view[:, i, :, j])

0 commit comments

Comments
 (0)