Skip to content

Commit 2a0a4a8

Browse files
committed
BF+TST - needs_scaling more robust, with tests
``needs_scaling`` routine hit numpy bugs in can_cast function for structured arrays, and float - int conversions in testing ranges for large uints. Found with added tests and fixed.
1 parent 392b418 commit 2a0a4a8

File tree

2 files changed

+94
-4
lines changed

2 files changed

+94
-4
lines changed

nibabel/arraywriters.py

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -90,27 +90,40 @@ def scaling_needed(self):
9090
data = self._array
9191
arr_dtype = data.dtype
9292
out_dtype = self._out_dtype
93-
if np.can_cast(arr_dtype, out_dtype):
94-
return False
93+
# There's a bug in np.can_cast (at least up to and including 1.6.1) such
94+
# that any structured output type passes. Check for this first.
9595
if 'V' in (arr_dtype.kind, out_dtype.kind):
96+
if arr_dtype == out_dtype:
97+
return False
9698
raise WriterError('Cannot cast to or from non-numeric types')
99+
if np.can_cast(arr_dtype, out_dtype):
100+
return False
101+
# Direct casting for complex output from any numeric type
97102
if out_dtype.kind == 'c':
98103
return False
99104
if arr_dtype.kind == 'c':
100105
raise WriterError('Cannot cast complex types to non-complex')
106+
# Direct casting for float output from any non-complex numeric type
101107
if out_dtype.kind == 'f':
102108
return False
103109
# Now we need to look at the data for special cases
104110
mn, mx = self.finite_range() # this is cached
105111
if (mn, mx) in ((0, 0), (np.inf, -np.inf)):
106112
# Data all zero, or no data is finite
107113
return False
114+
# Floats -> (u)ints always need scaling
108115
if arr_dtype.kind == 'f':
109116
return True
117+
# (u)int input, (u)int output
110118
assert arr_dtype.kind in 'iu' and out_dtype.kind in 'iu'
111119
info = np.iinfo(out_dtype)
112-
if mn >= info.min and mx <= info.max:
113-
return False
120+
# No scaling needed if data already fits in output type
121+
# But note - we need to convert to ints, to avoid conversion to float
122+
# during comparisons, and therefore int -> float conversions which are
123+
# not exact. Only a problem for uint64 though. We need as_int here to
124+
# work around a numpy 1.4.1 bug in uint conversion
125+
if as_int(mn) >= as_int(info.min) and as_int(mx) <= as_int(info.max):
126+
return False
114127
return True
115128

116129
@property

nibabel/tests/test_arraywriters.py

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,83 @@ def test_arraywriters():
9797
assert_true(arr_back.flags.c_contiguous)
9898

9999

100+
def test_scaling_needed():
101+
# Structured types return True if dtypes same, raise error otherwise
102+
dt_def = [('f', 'i4')]
103+
arr = np.ones(10, dt_def)
104+
for t in NUMERIC_TYPES:
105+
assert_raises(WriterError, ArrayWriter, arr, t)
106+
narr = np.ones(10, t)
107+
assert_raises(WriterError, ArrayWriter, narr, dt_def)
108+
assert_false(ArrayWriter(arr).scaling_needed())
109+
assert_false(ArrayWriter(arr, dt_def).scaling_needed())
110+
# Any numeric type that can cast, needs no scaling
111+
for in_t in NUMERIC_TYPES:
112+
for out_t in NUMERIC_TYPES:
113+
if np.can_cast(in_t, out_t):
114+
aw = ArrayWriter(np.ones(10, in_t), out_t)
115+
assert_false(aw.scaling_needed())
116+
for in_t in NUMERIC_TYPES:
117+
# Numeric types to complex never need scaling
118+
arr = np.ones(10, in_t)
119+
for out_t in COMPLEX_TYPES:
120+
assert_false(ArrayWriter(arr, out_t).scaling_needed())
121+
# Attempts to scale from complex to anything else fails
122+
for in_t in COMPLEX_TYPES:
123+
for out_t in FLOAT_TYPES + IUINT_TYPES:
124+
arr = np.ones(10, in_t)
125+
assert_raises(WriterError, ArrayWriter, arr, out_t)
126+
# Scaling from anything but complex to floats is OK
127+
for in_t in FLOAT_TYPES + IUINT_TYPES:
128+
arr = np.ones(10, in_t)
129+
for out_t in FLOAT_TYPES:
130+
assert_false(ArrayWriter(arr, out_t).scaling_needed())
131+
# For any other output type, arrays with no data don't need scaling
132+
for in_t in FLOAT_TYPES + IUINT_TYPES:
133+
arr_0 = np.zeros(10, in_t)
134+
arr_e = []
135+
for out_t in IUINT_TYPES:
136+
assert_false(ArrayWriter(arr_0, out_t).scaling_needed())
137+
assert_false(ArrayWriter(arr_e, out_t).scaling_needed())
138+
# Going to (u)ints, non-finite arrays don't need scaling
139+
for in_t in FLOAT_TYPES:
140+
arr_nan = np.zeros(10, in_t) + np.nan
141+
arr_inf = np.zeros(10, in_t) + np.inf
142+
arr_minf = np.zeros(10, in_t) - np.inf
143+
arr_mix = np.array([np.nan, np.inf, -np.inf], dtype=in_t)
144+
for out_t in IUINT_TYPES:
145+
for arr in (arr_nan, arr_inf, arr_minf, arr_mix):
146+
assert_false(ArrayWriter(arr, out_t).scaling_needed())
147+
# Floats as input always need scaling
148+
for in_t in FLOAT_TYPES:
149+
arr = np.ones(10, in_t)
150+
for out_t in IUINT_TYPES:
151+
# We need an arraywriter that will tolerate construction when
152+
# scaling is needed
153+
assert_true(SlopeArrayWriter(arr, out_t).scaling_needed())
154+
# in-range (u)ints don't need scaling
155+
for in_t in IUINT_TYPES:
156+
in_info = np.iinfo(in_t)
157+
in_min, in_max = in_info.min, in_info.max
158+
for out_t in IUINT_TYPES:
159+
out_info = np.iinfo(out_t)
160+
out_min, out_max = out_info.min, out_info.max
161+
if in_min >= out_min and in_max <= out_max:
162+
arr = np.array([in_min, in_max], in_t)
163+
assert_true(np.can_cast(arr.dtype, out_t))
164+
# We've already tested this with can_cast above, but...
165+
assert_false(ArrayWriter(arr, out_t).scaling_needed())
166+
continue
167+
# The output data type does not include the input data range
168+
max_min = max(in_min, out_min) # 0 for input or output uint
169+
min_max = min(in_max, out_max)
170+
arr = np.array([max_min, min_max], in_t)
171+
assert_false(ArrayWriter(arr, out_t).scaling_needed())
172+
assert_true(SlopeInterArrayWriter(arr + 1, out_t).scaling_needed())
173+
if in_t in INT_TYPES:
174+
assert_true(SlopeInterArrayWriter(arr - 1, out_t).scaling_needed())
175+
176+
100177
def test_special_rt():
101178
# Test that zeros; none finite - round trip to zeros
102179
for arr in (np.array([np.inf, np.nan, -np.inf]),

0 commit comments

Comments
 (0)