Skip to content

Commit 7d48ef1

Browse files
committed
BF+TST - test calc scale force and add reset
Think a bit more about caching of slope, inter and finite_range. Add tests for recalculation of slope, scale. Add reset method and test.
1 parent 2a0a4a8 commit 7d48ef1

File tree

2 files changed

+44
-13
lines changed

2 files changed

+44
-13
lines changed

nibabel/arraywriters.py

Lines changed: 19 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,7 @@ def __init__(self, array, out_dtype=None, calc_scale=True):
6666
else:
6767
out_dtype = np.dtype(out_dtype)
6868
self._out_dtype = out_dtype
69+
self._finite_range = None
6970
if self.scaling_needed():
7071
raise WriterError("Scaling needed but cannot scale")
7172

@@ -138,11 +139,8 @@ def out_dtype(self):
138139

139140
def finite_range(self):
140141
""" Return (maybe cached) finite range of data array """
141-
try:
142-
return self._finite_range
143-
except AttributeError:
144-
pass
145-
self._finite_range = finite_range(self._array)
142+
if self._finite_range is None:
143+
self._finite_range = finite_range(self._array)
146144
return self._finite_range
147145

148146
def _writing_range(self):
@@ -223,13 +221,17 @@ def __init__(self, array, out_dtype=None, calc_scale=True,
223221
else:
224222
out_dtype = np.dtype(out_dtype)
225223
self._out_dtype = out_dtype
226-
self.needs_scale = self.scaling_needed()
227224
self.scaler_dtype = np.dtype(scaler_dtype)
228-
self.slope = 1.0
229-
self._scale_calced = False
225+
self.reset()
230226
if calc_scale:
231227
self.calc_scale()
232228

229+
def reset(self):
230+
""" Set object to values before any scaling calculation """
231+
self.slope = 1.0
232+
self._finite_range = None
233+
self._scale_calced = False
234+
233235
def _get_slope(self):
234236
return self._slope
235237
def _set_slope(self, val):
@@ -242,10 +244,11 @@ def calc_scale(self, force=False):
242244
# If we've run already, return unless told otherwise
243245
if not force and self._scale_calced:
244246
return
245-
self._scale_calced = True
247+
self.reset()
246248
if not self.scaling_needed():
247249
return
248250
self._do_scaling()
251+
self._scale_calced = True
249252

250253
def to_fileobj(self, fileobj, order='F', nan2zero=True):
251254
""" Write array into `fileobj`
@@ -361,11 +364,15 @@ def __init__(self, array, out_dtype=None, calc_scale=True,
361364
>>> (aw.slope, aw.inter) == (1.0, 128)
362365
True
363366
"""
364-
super(SlopeInterArrayWriter, self).__init__(array, out_dtype, False,
367+
super(SlopeInterArrayWriter, self).__init__(array,
368+
out_dtype,
369+
calc_scale,
365370
scaler_dtype)
371+
372+
def reset(self):
373+
""" Set object to values before any scaling calculation """
374+
super(SlopeInterArrayWriter, self).reset()
366375
self.inter = 0.0
367-
if calc_scale:
368-
self.calc_scale()
369376

370377
def _get_inter(self):
371378
return self._inter

nibabel/tests/test_arraywriters.py

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -268,7 +268,31 @@ def test_calculate_scale():
268268
assert_raises(WriterError, SAW, npa([-1, 1], dtype=np.int8), np.uint8)
269269
# Offset trick can't work when max is out of range
270270
aw = SIAW(npa([-1, 255], dtype=np.int16), np.uint8)
271-
assert_not_equal(get_slope_inter(aw), (1.0, -1.0))
271+
slope_inter = get_slope_inter(aw)
272+
assert_not_equal(slope_inter, (1.0, -1.0))
273+
274+
275+
def test_resets():
276+
# Test reset of values, caching of scales
277+
for klass, inp, outp in ((SlopeInterArrayWriter, (1, 511), (2.0, 1.0)),
278+
(SlopeArrayWriter, (0, 510), (2.0, 0.0))):
279+
arr = np.array(inp)
280+
outp = np.array(outp)
281+
aw = klass(arr, np.uint8)
282+
assert_array_equal(get_slope_inter(aw), outp)
283+
aw.calc_scale() # cached no change
284+
assert_array_equal(get_slope_inter(aw), outp)
285+
aw.calc_scale(force=True) # same data, no change
286+
assert_array_equal(get_slope_inter(aw), outp)
287+
# Change underlying array
288+
aw.array[:] = aw.array * 2
289+
aw.calc_scale() # cached still
290+
assert_array_equal(get_slope_inter(aw), outp)
291+
aw.calc_scale(force=True) # new data, change
292+
assert_array_equal(get_slope_inter(aw), outp * 2)
293+
# Test reset
294+
aw.reset()
295+
assert_array_equal(get_slope_inter(aw), (1.0, 0.0))
272296

273297

274298
def test_no_offset_scale():

0 commit comments

Comments
 (0)