Skip to content

Commit 96716e1

Browse files
author
Carwyn Pelley
committed
MAINT: Changes based on review
1 parent 2a3e753 commit 96716e1

File tree

2 files changed

+14
-5
lines changed

2 files changed

+14
-5
lines changed

stratify/_vinterp.pyx

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -548,7 +548,7 @@ cdef class _Interpolation(object):
548548
# Compute the axis in absolute terms.
549549
fp_axis = (axis + fz_src.ndim) % fz_src.ndim
550550
zp_axis = fp_axis - (fz_src.ndim - z_src.ndim)
551-
if (not 0 <= zp_axis < z_src.ndim or axis > fz_src.ndim):
551+
if (not 0 <= zp_axis < z_src.ndim) or (axis >= fz_src.ndim):
552552
raise ValueError('Axis {} out of range.'.format(axis))
553553

554554
# Ensure that fz_src's shape is a superset of z_src's.

stratify/tests/test_vinterp.py

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -341,22 +341,31 @@ def test_inconsistent_shape(self):
341341

342342
def test_axis_out_of_bounds_fz_src_relative(self):
343343
# axis is out of bounds as identified by the absolute axis with z_src.
344-
data = np.empty([5, 4])
345-
zdata = np.empty([5, 4])
344+
data = np.empty((5, 4))
345+
zdata = np.empty((5, 4))
346346
axis = 4
347347
emsg = 'Axis {} out of range'
348348
with self.assertRaisesRegexp(ValueError, emsg.format(axis)):
349349
vinterp._Interpolation([1, 3], data, zdata, axis=axis)
350350

351351
def test_axis_out_of_bounds_z_src_absolute(self):
352352
# axis is out of bounds as identified by the relative axis with fz_src.
353-
data = np.empty([5, 4])
354-
zdata = np.empty([3, 5, 4])
353+
data = np.empty((5, 4))
354+
zdata = np.empty((3, 5, 4))
355355
axis = 0
356356
emsg = 'Axis {} out of range'
357357
with self.assertRaisesRegexp(ValueError, emsg.format(axis)):
358358
vinterp._Interpolation([1, 3], data, zdata, axis=axis)
359359

360+
def test_axis_greater_than_z_src_ndim(self):
361+
# Ensure that axis is not unnecessarily constrained to the dimensions
362+
# of z_src.
363+
data = np.empty((4))
364+
zdata = np.empty((3, 5, 4))
365+
axis = 2
366+
result = vinterp._Interpolation(data.copy(), data, zdata, axis=axis)
367+
self.assertEqual(result.result_shape, (3, 5, 4))
368+
360369
def test_nd_inconsistent_ndims(self):
361370
z_target = np.empty((2, 3, 4))
362371
z_src = np.empty((3, 4))

0 commit comments

Comments
 (0)