Skip to content

Commit 502cee1

Browse files
committed
Handled linear extrapolation exception when fewer than 2 interpolation points.
1 parent 1a872d7 commit 502cee1

File tree

2 files changed

+40
-9
lines changed

2 files changed

+40
-9
lines changed

stratify/_vinterp.pyx

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -277,13 +277,19 @@ cdef long linear_extrap(int direction, double[:] z_src,
277277
double[:] fz_target) nogil except -1:
278278
"""Linear extrapolation using either the first or last 2 values."""
279279
cdef unsigned int m = fz_src.shape[0]
280+
cdef unsigned int n_src_pts = fz_src.shape[1]
280281
cdef unsigned int p0, p1, i
281282
cdef double frac
282283

284+
if n_src_pts < 2:
285+
with gil:
286+
raise ValueError('Linear extrapolation requires at least '
287+
'2 source points. Got {}.'.format(n_src_pts))
288+
283289
if direction < 0:
284290
p0, p1 = 0, 1
285291
else:
286-
p0, p1 = fz_src.shape[1] - 2, fz_src.shape[1] - 1
292+
p0, p1 = n_src_pts - 2, n_src_pts - 1
287293

288294
frac = ((level - z_src[p0]) /
289295
(z_src[p1] - z_src[p0]))

stratify/tests/test_vinterp.py

Lines changed: 33 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -145,6 +145,19 @@ def test_high_precision(self):
145145
assert_array_almost_equal(self.interpolate([1.123456789]),
146146
[11.23456789], decimal=6)
147147

148+
def test_single_point(self):
149+
# Test that a single input point that falls exactly on the target
150+
# level triggers a shortcut that avoids the expectation of >=2 source
151+
# points.
152+
interpolation = stratify.INTERPOLATE_LINEAR
153+
extrapolation = vinterp._TestableDirectionExtrapKernel()
154+
155+
r = stratify.interpolate([2], [2], [20],
156+
interpolation=interpolation,
157+
extrapolation=extrapolation,
158+
rising=True)
159+
self.assertEqual(r, 20)
160+
148161

149162
class Test_INTERPOLATE_NEAREST(unittest.TestCase):
150163
def interpolate(self, x_target):
@@ -156,8 +169,8 @@ def interpolate(self, x_target):
156169

157170
# Use -2 to test negative number support.
158171
return stratify.interpolate(np.array(x_target) - 2, x_src - 2, fx_src,
159-
interpolation=interpolation,
160-
extrapolation=extrapolation)
172+
interpolation=interpolation,
173+
extrapolation=extrapolation)
161174

162175
def test_on_the_mark(self):
163176
assert_array_equal(self.interpolate([0, 1, 2, 3, 4]),
@@ -183,8 +196,8 @@ def interpolate(self, x_target):
183196

184197
# Use -2 to test negative number support.
185198
return stratify.interpolate(np.array(x_target) - 2, x_src - 2, fx_src,
186-
interpolation=interpolation,
187-
extrapolation=extrapolation)
199+
interpolation=interpolation,
200+
extrapolation=extrapolation)
188201

189202
def test_below(self):
190203
assert_array_equal(self.interpolate([-1]), [np.nan])
@@ -203,8 +216,8 @@ def interpolate(self, x_target):
203216

204217
# Use -2 to test negative number support.
205218
return stratify.interpolate(np.array(x_target) - 2, x_src - 2, fx_src,
206-
interpolation=interpolation,
207-
extrapolation=extrapolation)
219+
interpolation=interpolation,
220+
extrapolation=extrapolation)
208221

209222
def test_below(self):
210223
assert_array_equal(self.interpolate([-1]), [0.])
@@ -225,15 +238,27 @@ def interpolate(self, x_target):
225238

226239
# Use -2 to test negative number support.
227240
return stratify.interpolate(np.array(x_target) - 2, x_src - 2, fx_src,
228-
interpolation=interpolation,
229-
extrapolation=extrapolation)
241+
interpolation=interpolation,
242+
extrapolation=extrapolation)
230243

231244
def test_below(self):
232245
assert_array_equal(self.interpolate([-1]), [-10.])
233246

234247
def test_above(self):
235248
assert_array_almost_equal(self.interpolate([15.123]), [151.23])
236249

250+
def test_npts(self):
251+
interpolation = vinterp._TestableIndexInterpKernel()
252+
extrapolation = stratify.EXTRAPOLATE_LINEAR
253+
254+
msg = (r'Linear extrapolation requires at least 2 '
255+
r'source points. Got 1.')
256+
257+
with self.assertRaisesRegexp(ValueError, msg):
258+
stratify.interpolate([1, 3.], [2], [20],
259+
interpolation=interpolation,
260+
extrapolation=extrapolation, rising=True)
261+
237262

238263
class Test__Interpolator(unittest.TestCase):
239264
def test_axis_m1(self):

0 commit comments

Comments
 (0)