Skip to content

Commit 397f21b

Browse files
committed
NF: Enable data scaling within the target dtype
1 parent ddf2683 commit 397f21b

File tree

2 files changed

+63
-23
lines changed

2 files changed

+63
-23
lines changed

nibabel/arrayproxy.py

Lines changed: 54 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -336,36 +336,68 @@ def _get_fileobj(self):
336336
self.file_like, keep_open=False) as opener:
337337
yield opener
338338

339+
def _get_unscaled(self, slicer):
    """ Read raw (unscaled) data from file, optionally restricted by a slicer

    Parameters
    ----------
    slicer : object
        Slice specification passed through to ``fileslice``. The empty
        tuple ``()`` requests the whole array.

    Returns
    -------
    array
        Data as read from the file, with no scaling applied.
    """
    if isinstance(slicer, tuple) and not slicer:
        # Whole-array read: go through ``array_from_file`` so that the
        # proxy's ``mmap`` setting is still honored. ``fileslice`` takes no
        # mmap argument, so routing full reads through it would drop that
        # setting.
        with self._get_fileobj() as fileobj, self._lock:
            return array_from_file(self._shape,
                                   self._dtype,
                                   fileobj,
                                   offset=self._offset,
                                   order=self.order,
                                   mmap=self._mmap)
    # Partial read: ``fileslice`` is handed the lock rather than us holding
    # it around the call.
    with self._get_fileobj() as fileobj:
        return fileslice(fileobj,
                         slicer,
                         self._shape,
                         self._dtype,
                         self._offset,
                         order=self.order,
                         lock=self._lock)
348+
349+
def _get_scaled(self, dtype, slicer):
    """ Read (optionally sliced) data from file and apply slope / intercept

    Parameters
    ----------
    dtype : numpy dtype specifier or None
        Data type the scale factors are cast to before scaling. If None,
        the common type of the slope and intercept dtypes is used.
    slicer : object
        Slice specification forwarded to ``self._get_unscaled``; ``()``
        selects the whole array.

    Returns
    -------
    array
        Result of ``apply_read_scaling`` on the raw data.
    """
    # Ensure scale factors have dtypes
    scl_slope = np.asanyarray(self._slope)
    scl_inter = np.asanyarray(self._inter)
    if dtype is None:
        # dtype comparison is only a partial order (based on safe casting),
        # so ``max`` silently returns its first argument when neither dtype
        # can be cast to the other. ``promote_types`` always yields a dtype
        # that both factors can be safely cast to.
        dtype = np.promote_types(scl_slope.dtype, scl_inter.dtype)
    slope = scl_slope.astype(dtype)
    inter = scl_inter.astype(dtype)
    # Read array
    raw_data = self._get_unscaled(slicer=slicer)
    # Upcast as necessary for big slopes, intercepts
    return apply_read_scaling(raw_data, slope, inter)
361+
339362
def get_unscaled(self):
    """ Read the full data array from file, without applying any scaling

    This is an optional part of the proxy API
    """
    # An empty slicer selects the whole array
    everything = ()
    return self._get_unscaled(slicer=everything)
368+
369+
def get_scaled(self, dtype=None):
    """ Read the full data array from file and apply scaling

    The returned array has the narrowest dtype that can represent the
    scaled data without overflow, while being at least as wide as the
    requested `dtype`.

    When `dtype` is not given, the wider of the slope and intercept dtypes
    is used. These are generally determined by the parameter size in the
    image header, so they should be consistent within an image format but
    may differ between formats. Notably, the scale factors are
    single-precision (32-bit) floats for NIfTI-1 and double-precision
    (64-bit) floats for NIfTI-2.

    Parameters
    ----------
    dtype : numpy dtype specifier, optional
        A numpy dtype specifier giving the narrowest acceptable dtype.

    Returns
    -------
    array
        Scaled image data, with a data type at least as wide as `dtype`.
    """
    # An empty slicer selects the whole array
    everything = ()
    return self._get_scaled(dtype=dtype, slicer=everything)
352395

353396
def __array__(self, dtype=None):
    """ Read the full data array from file, apply scaling, return an array

    Parameters
    ----------
    dtype : numpy dtype specifier, optional
        Data type of the returned array. NumPy passes this argument when a
        dtype is requested, e.g. ``np.asarray(proxy, dtype=...)``. If None
        (the default), the dtype is whatever the scaling step produces.

    Returns
    -------
    array
        Scaled data; of type `dtype` if one was given.
    """
    arr = self._get_scaled(dtype=dtype, slicer=())
    if dtype is not None:
        # Scaling may produce a different type than requested; the array
        # protocol expects exactly the requested dtype. ``copy=False``
        # avoids a copy when the type already matches.
        arr = arr.astype(dtype, copy=False)
    return arr
357398

358399
def __getitem__(self, slicer):
    """ Read the part of the array selected by `slicer` and apply scaling """
    # No dtype is requested here, so the scaling step picks one itself
    return self._get_scaled(slicer=slicer, dtype=None)
369401

370402
def reshape(self, shape):
371403
""" Return an ArrayProxy with a new shape, without modifying data """

nibabel/dataobj_images.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010

1111
import numpy as np
1212

13+
from .arrayproxy import is_proxy
1314
from .filebasedimages import FileBasedImage
1415
from .keywordonly import kw_only_meth
1516
from .deprecated import deprecate_with_version
@@ -350,7 +351,14 @@ def get_fdata(self, caching='fill', dtype=np.float64):
350351
if self._fdata_cache is not None:
351352
if self._fdata_cache.dtype.type == dtype.type:
352353
return self._fdata_cache
353-
data = np.asanyarray(self._dataobj).astype(dtype, copy=False)
354+
dataobj = self._dataobj
355+
# Attempt to confine data array to dtype during scaling
356+
# On overflow, may still upcast
357+
if is_proxy(dataobj):
358+
dataobj = dataobj.get_scaled(dtype=dtype)
359+
# Always return requested data type
360+
# For array proxies, will only copy on overflow
361+
data = np.asanyarray(dataobj, dtype=dtype)
354362
if caching == 'fill':
355363
self._fdata_cache = data
356364
return data

0 commit comments

Comments
 (0)