Skip to content

Commit 392b418

Browse files
committed
BF+TST - fix max min overflow in int scaling
Add tests for various cases. Fix corner case of negative numbers with 0 as max, with test.
1 parent 0be65d0 commit 392b418

File tree

2 files changed

+79
-7
lines changed

2 files changed

+79
-7
lines changed

nibabel/arraywriters.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -278,8 +278,7 @@ def _do_scaling(self):
278278
if self._out_dtype.kind == 'u':
279279
shared_min, shared_max = shared_range(self.scaler_dtype,
280280
self._out_dtype)
281-
mn, mx = self.finite_range()
282-
if mx < 0 and abs(mn) <= shared_max: # sign flip enough?
281+
if mx <= 0 and abs(mn) <= shared_max: # sign flip enough?
283282
# -1.0 * arr will be in scaler_dtype precision
284283
self.slope = -1.0
285284
return
@@ -299,7 +298,7 @@ def _range_scale(self):
299298
if mn < 0 and mx > 0:
300299
raise WriterError('Cannot scale negative and positive '
301300
'numbers to uint without intercept')
302-
if mx < 0: # All input numbers < 0
301+
if mx <= 0: # All input numbers <= 0
303302
self.slope = mn / shared_max
304303
else: # All input numbers > 0
305304
self.slope = mx / shared_max
@@ -412,11 +411,13 @@ def _do_scaling(self):
412411
if self._out_dtype.kind == 'u':
413412
shared_min, shared_max = shared_range(self.scaler_dtype,
414413
self._out_dtype)
415-
mn, mx = self.finite_range()
416-
if (mx - mn) <= shared_max: # offset enough?
414+
# range may be greater than the largest integer for this type.
415+
# as_int needed to work round numpy 1.4.1 int casting bug
416+
mn2mx = as_int(mx) - as_int(mn)
417+
if mn2mx <= shared_max: # offset enough?
417418
self.inter = mn
418419
return
419-
if mx < 0 and abs(mn) <= shared_max: # sign flip enough?
420+
if mx <= 0 and abs(mn) <= shared_max: # sign flip enough?
420421
# -1.0 * arr will be in scaler_dtype precision
421422
self.slope = -1.0
422423
return

nibabel/tests/test_arraywriters.py

Lines changed: 72 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,8 @@ def __init__(self, array, out_dtype=None, order='F')
3535
WriterError, ScalingError, ArrayWriter,
3636
make_array_writer, get_slope_inter)
3737

38+
from ..casting import int_abs
39+
3840
from ..volumeutils import array_from_file, apply_read_scaling
3941

4042
from numpy.testing import (assert_array_almost_equal,
@@ -171,9 +173,17 @@ def test_calculate_scale():
171173
# Offset handles scaling when it can
172174
aw = SIAW(npa([-2, -1], dtype=np.int8), np.uint8)
173175
assert_equal(get_slope_inter(aw), (1.0, -2.0))
174-
# Sign flip handles this case
176+
# Sign flip handles these cases
175177
aw = SAW(npa([-2, -1], dtype=np.int8), np.uint8)
176178
assert_equal(get_slope_inter(aw), (-1.0, 0.0))
179+
aw = SAW(npa([-2, 0], dtype=np.int8), np.uint8)
180+
assert_equal(get_slope_inter(aw), (-1.0, 0.0))
181+
# But not when min magnitude is too large (scaling mechanism kicks in)
182+
aw = SAW(npa([-510, 0], dtype=np.int16), np.uint8)
183+
assert_equal(get_slope_inter(aw), (-2.0, 0.0))
184+
# Or for floats (attempts to expand across range)
185+
aw = SAW(npa([-2, 0], dtype=np.float32), np.uint8)
186+
assert_not_equal(get_slope_inter(aw), (-1.0, 0.0))
177187
# Case where offset handles scaling
178188
aw = SIAW(npa([-1, 1], dtype=np.int8), np.uint8)
179189
assert_equal(get_slope_inter(aw), (1.0, -1.0))
@@ -187,9 +197,12 @@ def test_calculate_scale():
187197
def test_no_offset_scale():
188198
# Specific tests of no-offset scaling
189199
SAW = SlopeArrayWriter
200+
# Floating point
190201
for data in ((-128, 127),
191202
(-128, 126),
192203
(-128, -127),
204+
(-128, 0),
205+
(-128, -1),
193206
(126, 127),
194207
(-127, 127)):
195208
aw = SAW(np.array(data, dtype=np.float32), np.int8)
@@ -200,6 +213,19 @@ def test_no_offset_scale():
200213
assert_equal(aw.slope, 2)
201214

202215

216+
def test_with_offset_scale():
217+
# Tests of specific cases in slope, inter
218+
SIAW = SlopeInterArrayWriter
219+
aw = SIAW(np.array([0, 127], dtype=np.int8), np.uint8)
220+
assert_equal((aw.slope, aw.inter), (1, 0)) # in range
221+
aw = SIAW(np.array([-1, 126], dtype=np.int8), np.uint8)
222+
assert_equal((aw.slope, aw.inter), (1, -1)) # offset only
223+
aw = SIAW(np.array([-1, 254], dtype=np.int16), np.uint8)
224+
assert_equal((aw.slope, aw.inter), (1, -1)) # offset only
225+
aw = SIAW(np.array([-1, 255], dtype=np.int16), np.uint8)
226+
assert_not_equal((aw.slope, aw.inter), (1, -1)) # Too big for offset only
227+
228+
203229
def test_io_scaling():
204230
# Test scaling works for max, min when going from larger to smaller type,
205231
# and from float to integer.
@@ -364,6 +390,51 @@ def test_float_int_min_max():
364390
assert_true(np.allclose(arr, arr_back_sc))
365391

366392

393+
def test_int_int_min_max():
394+
# Conversion between (u)int and (u)int
395+
eps = np.finfo(np.float64).eps
396+
rtol = 1e-6
397+
for in_dt in IUINT_TYPES:
398+
iinf = np.iinfo(in_dt)
399+
arr = np.array([iinf.min, iinf.max], dtype=in_dt)
400+
for out_dt in IUINT_TYPES:
401+
try:
402+
aw = SlopeInterArrayWriter(arr, out_dt)
403+
except ScalingError:
404+
continue
405+
arr_back_sc = round_trip(aw)
406+
# integer allclose
407+
adiff = int_abs(arr - arr_back_sc)
408+
rdiff = adiff / (arr + eps)
409+
assert_true(np.all(rdiff < rtol))
410+
411+
412+
def test_int_int_slope():
413+
# Conversion between (u)int and (u)int for slopes only
414+
eps = np.finfo(np.float64).eps
415+
rtol = 1e-7
416+
for in_dt in IUINT_TYPES:
417+
iinf = np.iinfo(in_dt)
418+
for out_dt in IUINT_TYPES:
419+
kinds = np.dtype(in_dt).kind + np.dtype(out_dt).kind
420+
if kinds in ('ii', 'uu', 'ui'):
421+
arrs = (np.array([iinf.min, iinf.max], dtype=in_dt),)
422+
elif kinds == 'iu':
423+
arrs = (np.array([iinf.min, 0], dtype=in_dt),
424+
np.array([0, iinf.max], dtype=in_dt))
425+
for arr in arrs:
426+
try:
427+
aw = SlopeArrayWriter(arr, out_dt)
428+
except ScalingError:
429+
continue
430+
assert_false(aw.slope == 0)
431+
arr_back_sc = round_trip(aw)
432+
# integer allclose
433+
adiff = int_abs(arr - arr_back_sc)
434+
rdiff = adiff / (arr + eps)
435+
assert_true(np.all(rdiff < rtol))
436+
437+
367438
def test_float_int_spread():
368439
# Test rounding error for spread of values
369440
powers = np.arange(-10, 10, 0.5)

0 commit comments

Comments
 (0)