Skip to content

Commit fc2d534

Browse files
committed
NF - allow forcing inf values to integer max/min
Flag to force integer values to maximum and minimum of the integer dtype.
1 parent 773cbcf commit fc2d534

File tree

2 files changed

+41
-9
lines changed

2 files changed

+41
-9
lines changed

nibabel/casting.py

Lines changed: 18 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ class RoundingError(Exception):
4545
pass
4646

4747

48-
def nice_round(arr, int_type, nan2zero=True):
48+
def nice_round(arr, int_type, nan2zero=True, infmax=False):
4949
""" Round floating point array `arr` to type `int_type`
5050
5151
Parameters
@@ -57,6 +57,12 @@ def nice_round(arr, int_type, nan2zero=True):
5757
nan2zero : {True, False}
5858
Whether to convert NaN value to zero. Default is True. If False, and
5959
NaNs are present, raise RoundingError
60+
infmax : {False, True}
61+
If True, set np.inf values in `arr` to be `int_type` integer maximum
62+
value, -np.inf as `int_type` integer minimum. If False, merely set infs
63+
to be numbers at or near the maximum / minumum number in `arr` that can be
64+
contained in `int_type`. Therefore False gives faster conversion at the
65+
expense of infs that are further from infinity.
6066
6167
Returns
6268
-------
@@ -67,12 +73,6 @@ def nice_round(arr, int_type, nan2zero=True):
6773
--------
6874
>>> nice_round([np.nan, np.inf, -np.inf, 1.1, 6.6], np.int16)
6975
array([ 0, 32767, -32768, 1, 7], dtype=int16)
70-
71-
Notes
72-
-----
73-
We always set +-inf to be the min / max of the integer type. If you want
74-
something different you'll need to filter them before passing to this
75-
routine.
7676
"""
7777
arr = np.asarray(arr)
7878
flt_type = arr.dtype.type
@@ -85,4 +85,15 @@ def nice_round(arr, int_type, nan2zero=True):
8585
iarr = np.clip(np.rint(arr), mn, mx).astype(int_type)
8686
if have_nans:
8787
iarr[nans] = 0
88+
if not infmax:
89+
return iarr
90+
# Deal with scalar as input
91+
shape = iarr.shape
92+
iarr = np.atleast_1d(iarr)
93+
arr = np.atleast_1d(arr)
94+
ii = np.iinfo(int_type)
95+
iarr[arr == np.inf] = ii.max
96+
if ii.min != int(mn):
97+
iarr[arr == -np.inf] = ii.min
98+
iarr.shape = shape
8899
return iarr

nibabel/tests/test_casting.py

Lines changed: 23 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,27 @@ def test_casting():
4040
ii = np.iinfo(it)
4141
arr = [ii.min-1, ii.max+1, -np.inf, np.inf, np.nan, 0.2, 10.6]
4242
farr = np.array(arr, dtype=ft)
43-
iarr = nice_round(farr, it)
4443
mn, mx = int_clippers(ft, it)
45-
assert_array_equal(iarr, [mn, mx, mn, mx, 0, 0, 11])
44+
iarr = nice_round(farr, it)
45+
exp_arr = np.array([mn, mx, mn, mx, 0, 0, 11])
46+
assert_array_equal(iarr, exp_arr)
47+
iarr = nice_round(farr, it, infmax=True)
48+
# Float16 can overflow to infs
49+
if farr[0] == -np.inf:
50+
exp_arr[0] = ii.min
51+
if farr[1] == np.inf:
52+
exp_arr[1] = ii.max
53+
exp_arr[2] = ii.min
54+
if exp_arr.dtype.type is np.longdouble:
55+
# longdouble seems to go through float64 on assignment; if
56+
# ii.max is above float64 integer resolution we have go through
57+
# float64 to split up the number and get full precision
58+
f64 = np.float64(ii.max)
59+
exp_arr[3] = np.longdouble(f64) + np.float64(ii.max - int(f64))
60+
else:
61+
exp_arr[3] = ii.max
62+
assert_array_equal(iarr, exp_arr)
63+
# Confirm input array is not modified
64+
assert_array_equal(farr, np.array(arr, dtype=ft))
65+
# Test scalars work and return scalars
66+
assert_array_equal(nice_round(np.float32(0), np.int16), [0])

0 commit comments

Comments
 (0)