Skip to content

Commit ea3e607

Browse files
committed
NF: Enable data scaling within the target dtype
1 parent a7acbac commit ea3e607

File tree

2 files changed

+72
-23
lines changed

2 files changed

+72
-23
lines changed

nibabel/arrayproxy.py

Lines changed: 63 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@
3333

3434
from .deprecated import deprecate_with_version
3535
from .volumeutils import array_from_file, apply_read_scaling
36-
from .fileslice import fileslice
36+
from .fileslice import fileslice, canonical_slicers
3737
from .keywordonly import kw_only_meth
3838
from . import openers
3939

@@ -336,36 +336,77 @@ def _get_fileobj(self):
336336
self.file_like, keep_open=False) as opener:
337337
yield opener
338338

339-
def get_unscaled(self):
340-
""" Read of data from file
341-
342-
This is an optional part of the proxy API
343-
"""
344-
with self._get_fileobj() as fileobj, self._lock:
345-
raw_data = array_from_file(self._shape,
339+
def _get_unscaled(self, slicer):
340+
if canonical_slicers(slicer, self._shape, False) == \
341+
canonical_slicers((), self._shape, False):
342+
with self._get_fileobj() as fileobj, self._lock:
343+
return array_from_file(self._shape,
346344
self._dtype,
347345
fileobj,
348346
offset=self._offset,
349347
order=self.order,
350348
mmap=self._mmap)
351-
return raw_data
349+
with self._get_fileobj() as fileobj:
350+
return fileslice(fileobj,
351+
slicer,
352+
self._shape,
353+
self._dtype,
354+
self._offset,
355+
order=self.order,
356+
lock=self._lock)
357+
358+
def _get_scaled(self, dtype, slicer):
359+
# Ensure scale factors have dtypes
360+
scl_slope = np.asanyarray(self._slope)
361+
scl_inter = np.asanyarray(self._inter)
362+
if dtype is None:
363+
dtype = scl_slope.dtype
364+
slope = scl_slope.astype(dtype)
365+
inter = scl_inter.astype(dtype)
366+
# Read array
367+
raw_data = self._get_unscaled(slicer=slicer)
368+
# Upcast as necessary for big slopes, intercepts
369+
return apply_read_scaling(raw_data, slope, inter)
370+
371+
def get_unscaled(self):
372+
""" Read data from file
373+
374+
This is an optional part of the proxy API
375+
"""
376+
return self._get_unscaled(slicer=())
377+
378+
def get_scaled(self, dtype=None):
379+
""" Read data from file and apply scaling
380+
381+
The dtype of the returned array is the narrowest dtype that can
382+
represent the data without overflow, and is at least as wide as
383+
the dtype parameter.
384+
385+
If dtype is unspecified, it is the wider of the dtypes of the slope
386+
or intercept. This will generally be determined by the parameter
387+
size in the image header, and so should be consistent for a given
388+
image format, but may vary across formats. Notably, these factors
389+
are single-precision (32-bit) floats for NIfTI-1 and double-precision
390+
(64-bit) floats for NIfTI-2.
391+
392+
Parameters
393+
----------
394+
dtype : numpy dtype specifier
395+
A numpy dtype specifier specifying the narrowest acceptable
396+
dtype.
397+
398+
Returns
399+
-------
400+
array
401+
Scaled image data with data type `dtype`.
402+
"""
403+
return self._get_scaled(dtype=dtype, slicer=())
352404

353405
def __array__(self):
354-
# Read array and scale
355-
raw_data = self.get_unscaled()
356-
return apply_read_scaling(raw_data, self._slope, self._inter)
406+
return self._get_scaled(dtype=None, slicer=())
357407

358408
def __getitem__(self, slicer):
359-
with self._get_fileobj() as fileobj:
360-
raw_data = fileslice(fileobj,
361-
slicer,
362-
self._shape,
363-
self._dtype,
364-
self._offset,
365-
order=self.order,
366-
lock=self._lock)
367-
# Upcast as necessary for big slopes, intercepts
368-
return apply_read_scaling(raw_data, self._slope, self._inter)
409+
return self._get_scaled(dtype=None, slicer=slicer)
369410

370411
def reshape(self, shape):
371412
""" Return an ArrayProxy with a new shape, without modifying data """

nibabel/dataobj_images.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010

1111
import numpy as np
1212

13+
from .arrayproxy import is_proxy
1314
from .filebasedimages import FileBasedImage
1415
from .keywordonly import kw_only_meth
1516
from .deprecated import deprecate_with_version
@@ -350,7 +351,14 @@ def get_fdata(self, caching='fill', dtype=np.float64):
350351
if self._fdata_cache is not None:
351352
if self._fdata_cache.dtype.type == dtype.type:
352353
return self._fdata_cache
353-
data = np.asanyarray(self._dataobj).astype(dtype, copy=False)
354+
dataobj = self._dataobj
355+
# Attempt to confine data array to dtype during scaling
356+
# On overflow, may still upcast
357+
if is_proxy(dataobj):
358+
dataobj = dataobj.get_scaled(dtype=dtype)
359+
# Always return requested data type
360+
# For array proxies, will only copy on overflow
361+
data = np.asanyarray(dataobj, dtype=dtype)
354362
if caching == 'fill':
355363
self._fdata_cache = data
356364
return data

0 commit comments

Comments
 (0)