Skip to content

Commit f2cb8de

Browse files
committed
Added a test for 0 gradient on the linear extrapolation.
1 parent 8326d28 commit f2cb8de

File tree

2 files changed

+22
-4
lines changed

2 files changed

+22
-4
lines changed

stratify/_vinterp.pyx

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -368,15 +368,21 @@ cdef class LinearExtrapolator(Extrapolator):
368368
cdef unsigned int m = fz_src.shape[0]
369369
cdef unsigned int n_src_pts = fz_src.shape[1]
370370
cdef unsigned int p0, p1, i
371-
cdef double frac
371+
cdef double frac, z_step
372372

373373
if direction < 0:
374374
p0, p1 = 0, 1
375375
else:
376376
p0, p1 = n_src_pts - 2, n_src_pts - 1
377377

378-
frac = ((level - z_src[p0]) /
379-
(z_src[p1] - z_src[p0]))
378+
# Compute the normalised distance of the target point between p0 and p1
379+
z_step = z_src[p1] - z_src[p0]
380+
if z_step == 0:
381+
# If there is nothing between the last two points then we
382+
# extrapolate using a 0 gradient.
383+
frac = 0
384+
else:
385+
frac = ((level - z_src[p0]) / z_step)
380386

381387
for i in range(m):
382388
fz_target[i] = fz_src[i, p0] + frac * (fz_src[i, p1] - fz_src[i, p0])

stratify/tests/test_vinterp.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -148,6 +148,12 @@ def test_on_the_mark(self):
148148
assert_array_equal(self.interpolate([0, 1, 2, 3, 4]),
149149
[0, 10, 20, 30, 40])
150150

151+
def test_zero_gradient(self):
152+
assert_array_equal(
153+
stratify.interpolate([1], [0, 1, 1, 2], [10, 20, 30, 40],
154+
interpolation='linear'),
155+
[20])
156+
151157
def test_inbetween(self):
152158
assert_array_equal(self.interpolate([0.5, 1.25, 2.5, 3.75]),
153159
[5, 12.5, 25, 37.5])
@@ -258,6 +264,12 @@ def test_below(self):
258264
def test_above(self):
259265
assert_array_almost_equal(self.interpolate([15.123]), [151.23])
260266

267+
def test_zero_gradient(self):
268+
assert_array_almost_equal(
269+
stratify.interpolate([2], [0, 0], [1, 1],
270+
extrapolation='linear'),
271+
[1])
272+
261273
def test_npts(self):
262274
interpolation = IndexInterpolator()
263275
extrapolation = stratify.EXTRAPOLATE_LINEAR
@@ -369,7 +381,7 @@ class Test_interpolate(unittest.TestCase):
369381
def test_target_z_3d_axis_0(self):
370382
z_target = z_source = f_source = np.arange(3) * np.ones([4, 2, 3])
371383
result= vinterp.interpolate(z_target, z_source, f_source,
372-
axis=0, extrapolation='linear')
384+
extrapolation='linear')
373385
assert_array_equal(result, f_source)
374386

375387

0 commit comments

Comments
 (0)