Skip to content

Commit 066c32a

Browse files
committed
FIX: Better type promotion
1 parent ac52ad8 commit 066c32a

File tree

2 files changed

+22
-10
lines changed

2 files changed

+22
-10
lines changed

nibabel/arrayproxy.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -359,14 +359,17 @@ def _get_scaled(self, dtype, slicer):
359359
# Ensure scale factors have dtypes
360360
scl_slope = np.asanyarray(self._slope)
361361
scl_inter = np.asanyarray(self._inter)
362-
if dtype is None:
363-
dtype = scl_slope.dtype
364-
slope = scl_slope.astype(dtype)
365-
inter = scl_inter.astype(dtype)
362+
use_dtype = scl_slope.dtype if dtype is None else dtype
363+
slope = scl_slope.astype(use_dtype)
364+
inter = scl_inter.astype(use_dtype)
366365
# Read array
367366
raw_data = self._get_unscaled(slicer=slicer)
368367
# Upcast as necessary for big slopes, intercepts
369-
return apply_read_scaling(raw_data, slope, inter)
368+
scaled = apply_read_scaling(raw_data, slope, inter)
369+
del raw_data
370+
if dtype is not None:
371+
scaled = scaled.astype(np.promote_types(scaled.dtype, dtype), copy=False)
372+
return scaled
370373

371374
def get_unscaled(self):
372375
""" Read data from file

nibabel/parrec.py

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -658,20 +658,29 @@ def _get_unscaled(self, slicer):
658658
def _get_scaled(self, dtype, slicer):
659659
raw_data = self._get_unscaled(slicer)
660660
if self._slice_scaling is None:
661-
if dtype is None or raw_data.dtype >= np.dtype(dtype):
661+
if dtype is None:
662662
return raw_data
663-
return np.asanyarray(raw_data, dtype=dtype)
663+
final_type = np.promote_types(raw_data.dtype, dtype)
664+
return raw_data.astype(final_type, copy=False)
664665

665666
# Broadcast scaling to shape of original data
666667
slopes, inters = self._slice_scaling
667668
fake_data = strided_scalar(self._shape)
668669
_, slopes, inters = np.broadcast_arrays(fake_data, slopes, inters)
669670

670-
if dtype is None:
671-
dtype = max(slopes.dtype, inters.dtype)
671+
final_type = np.result_type(raw_data, slopes, inters)
672+
if dtype is not None:
673+
final_type = np.promote_types(final_type, dtype)
674+
675+
slopes = slopes.astype(final_type)
676+
inters = inters.astype(final_type)
677+
678+
if slicer is not None:
679+
slopes = slopes[slicer]
680+
inters = inters[slicer]
672681

673682
# Slice scaling to give output shape
674-
return raw_data * slopes.astype(dtype)[slicer] + inters.astype(dtype)[slicer]
683+
return raw_data * slopes + inters
675684

676685

677686
def get_scaled(self, dtype=None):

0 commit comments

Comments
 (0)