Skip to content

Commit a4c978a

Browse files
committed
RF - add None NaN option; some casting cleanups
Add option to pass through NaNs without checking for them (for speed). Some other casting fixes during the tests.
1 parent d7ce13a commit a4c978a

File tree

2 files changed

+59
-33
lines changed

2 files changed

+59
-33
lines changed

nibabel/casting.py

Lines changed: 43 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -28,9 +28,11 @@ def float_to_int(arr, int_type, nan2zero=True, infmax=False):
2828
Array of floating point type
2929
int_type : object
3030
Numpy integer type
31-
nan2zero : {True, False}
31+
nan2zero : {True, False, None}
3232
Whether to convert NaN value to zero. Default is True. If False, and
33-
NaNs are present, raise CastingError
33+
NaNs are present, raise CastingError. If None, do not check for NaN
34+
values and pass through directly to the ``astype`` casting mechanism.
35+
In this last case, the resulting value is undefined.
3436
infmax : {False, True}
3537
If True, set np.inf values in `arr` to be `int_type` integer maximum
3638
value, -np.inf as `int_type` integer minimum. If False, set +/- infs to
@@ -72,13 +74,16 @@ def float_to_int(arr, int_type, nan2zero=True, infmax=False):
7274
# Deal with scalar as input; fancy indexing needs 1D
7375
shape = arr.shape
7476
arr = np.atleast_1d(arr)
75-
mn, mx = _cached_int_clippers(flt_type, int_type)
76-
nans = np.isnan(arr)
77-
have_nans = np.any(nans)
78-
if not nan2zero and have_nans:
79-
raise CastingError('NaNs in array, nan2zero not True')
77+
mn, mx = int_clippers(flt_type, int_type)
78+
if nan2zero is None:
79+
seen_nans = False
80+
else:
81+
nans = np.isnan(arr)
82+
seen_nans = np.any(nans)
83+
if nan2zero == False and seen_nans:
84+
raise CastingError('NaNs in array, nan2zero is False')
8085
iarr = np.clip(np.rint(arr), mn, mx).astype(int_type)
81-
if have_nans:
86+
if seen_nans:
8287
iarr[nans] = 0
8388
if not infmax:
8489
return iarr.reshape(shape)
@@ -89,6 +94,9 @@ def float_to_int(arr, int_type, nan2zero=True, infmax=False):
8994
return iarr.reshape(shape)
9095

9196

97+
# Cache range values
98+
_SHARED_RANGES = {}
99+
92100
def int_clippers(flt_type, int_type):
93101
""" Min and max in float type that are >=min, <=max in integer type
94102
@@ -98,10 +106,12 @@ def int_clippers(flt_type, int_type):
98106
99107
Parameters
100108
----------
101-
flt_type : object
102-
numpy floating point type
103-
int_type : object
104-
numpy integer type
109+
flt_type : dtype specifier
110+
A dtype specifier referring to a numpy floating point type. For
111+
example, ``f4``, ``np.dtype('f4')``, ``np.float32`` are equivalent.
112+
int_type : dtype specifier
113+
A dtype specifier referring to a numpy integer type. For example,
114+
``i4``, ``np.dtype('i4')``, ``np.int32`` are equivalent
105115
106116
Returns
107117
-------
@@ -111,23 +121,31 @@ def int_clippers(flt_type, int_type):
111121
mx : object
112122
Number of type `flt_type` that is the maximum value in the range of
113123
`int_type`, such that ``mx.astype(int_type)`` <= max of `int_type`
124+
125+
Examples
126+
--------
127+
>>> shared_range(np.float32, np.int32)
128+
(-2147483648.0, 2147483520.0)
129+
>>> shared_range('f4', 'i4')
130+
(-2147483648.0, 2147483520.0)
114131
"""
132+
flt_type = np.dtype(flt_type).type
133+
int_type = np.dtype(int_type).type
134+
key = (flt_type, int_type)
135+
# Used cached value if present
136+
try:
137+
return _SHARED_RANGES[key]
138+
except KeyError:
139+
pass
115140
ii = np.iinfo(int_type)
116-
return floor_exact(ii.min, flt_type), floor_exact(ii.max, flt_type)
117-
118-
119-
# Cache clip values
120-
FLT_INT_CLIPS = {}
121-
122-
def _cached_int_clippers(flt_type, int_type):
123-
if not (flt_type, int_type) in FLT_INT_CLIPS:
124-
FLT_INT_CLIPS[flt_type, int_type] = int_clippers(flt_type, int_type)
125-
return FLT_INT_CLIPS[(flt_type, int_type)]
141+
mn_mx = floor_exact(ii.min, flt_type), floor_exact(ii.max, flt_type)
142+
_SHARED_RANGES[key] = mn_mx
143+
return mn_mx
126144

127-
# ---------------------------------------------------------------------------
128-
# Routines to work out the next lowest representable intger in floating point
145+
# ----------------------------------------------------------------------------
146+
# Routines to work out the next lowest representable integer in floating point
129147
# types.
130-
# ---------------------------------------------------------------------------
148+
# ----------------------------------------------------------------------------
131149

132150
try:
133151
_float16 = np.float16

nibabel/tests/test_casting.py

Lines changed: 16 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -81,20 +81,28 @@ def test_casting():
8181
# We're later going to test if we modify this array
8282
farr = farr_orig.copy()
8383
mn, mx = int_clippers(ft, it)
84-
imn, imx = as_int(mn), as_int(mx)
8584
iarr = float_to_int(farr, it)
86-
exp_arr = [imn, imx, imn, imx, 0, 0, 11]
85+
# Dammit - for long doubles we need to jump through some hoops not
86+
# to round to numbers outside the range
87+
if ft is np.longdouble:
88+
mn = as_int(mn)
89+
mx = as_int(mx)
90+
exp_arr = np.array([mn, mx, mn, mx, 0, 0, 11], dtype=it)
8791
assert_array_equal(iarr, exp_arr)
92+
# Now test infmax version
8893
iarr = float_to_int(farr, it, infmax=True)
94+
im_exp = np.array([mn, mx, ii.min, ii.max, 0, 0, 11], dtype=it)
8995
# Float16 can overflow to infs
9096
if farr[0] == -np.inf:
91-
exp_arr[0] = ii.min
97+
im_exp[0] = ii.min
9298
if farr[1] == np.inf:
93-
exp_arr[1] = ii.max
94-
exp_arr[2] = ii.min
95-
exp_arr[3] = ii.max
96-
# Always comparing integers here, so no issues with int-float
97-
# casting in this comparison
99+
im_exp[1] = ii.max
100+
assert_array_equal(iarr, im_exp)
101+
# NaNs, with nan2zero False, gives error
102+
assert_raises(CastingError, float_to_int, farr, it, False)
103+
# We can pass through NaNs if we really want
104+
exp_arr[arr.index(np.nan)] = ft(np.nan).astype(it)
105+
iarr = float_to_int(farr, it, nan2zero=None)
98106
assert_array_equal(iarr, exp_arr)
99107
# Confirm input array is not modified
100108
nans = np.isnan(farr)

0 commit comments

Comments
 (0)