Skip to content

Commit 107705a

Browse files
committed
ENH: default read scaling from slope dtype
The apply_read_scaling routine was previously ignoring the dtype of slope and inter when choosing the floating point type to scale with. This change makes sure that the floating point type chosen for scaling has precision at least as high as the input slope dtype.
1 parent b182ec5 commit 107705a

File tree

2 files changed

+13
-4
lines changed

2 files changed

+13
-4
lines changed

nibabel/tests/test_utils.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -312,9 +312,16 @@ def test_apply_scaling():
312312
# Upcasting does occur with this routine
313313
assert_equal(apply_read_scaling(i16_arr, big).dtype, np.float64)
314314
assert_equal(apply_read_scaling(i16_arr, big_delta, big).dtype, np.float64)
315-
assert_equal(apply_read_scaling(np.int8(0), -1.0, 0.0).dtype, np.float32)
316-
assert_equal(apply_read_scaling(np.int8(0), 1e38, 0.0).dtype, np.float64)
317-
assert_equal(apply_read_scaling(np.int8(0), -1e38, 0.0).dtype, np.float64)
315+
# If float32 passed, no overflow, float32 returned
316+
assert_equal(apply_read_scaling(np.int8(0), f32(-1.0), f32(0.0)).dtype,
317+
np.float32)
318+
# float64 passed, float64 returned
319+
assert_equal(apply_read_scaling(np.int8(0), -1.0, 0.0).dtype, np.float64)
320+
# float32 passed, overflow, float64 returned
321+
assert_equal(apply_read_scaling(np.int8(0), f32(1e38), f32(0.0)).dtype,
322+
np.float64)
323+
assert_equal(apply_read_scaling(np.int8(0), f32(-1e38), f32(0.0)).dtype,
324+
np.float64)
318325

319326

320327
def test_int_scinter():

nibabel/volumeutils.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -747,7 +747,9 @@ def apply_read_scaling(arr, slope = 1.0, inter = 0.0):
747747
# Force float / float upcasting by promoting to arrays
748748
arr, slope, inter = [np.atleast_1d(v) for v in arr, slope, inter]
749749
if arr.dtype.kind in 'iu':
750-
ftype = int_scinter_ftype(arr.dtype, slope, inter)
750+
# Find floating point type for which scaling does not overflow, starting
751+
# at given type
752+
ftype = int_scinter_ftype(arr.dtype, slope, inter, slope.dtype.type)
751753
slope = slope.astype(ftype)
752754
inter = inter.astype(ftype)
753755
if slope != 1.0:

0 commit comments

Comments
 (0)