Skip to content

Commit 5cfc1f7

Browse files
committed
NF+BF - fix nans as scalar, test nan2zero error
Scalar nan gave an error for fancy indexing.
1 parent fc2d534 commit 5cfc1f7

File tree

2 files changed

+14
-13
lines changed

2 files changed

+14
-13
lines changed

nibabel/casting.py

Lines changed: 8 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ def _cached_int_clippers(flt_type, int_type):
4141
return FLT_INT_CLIPS[(flt_type, int_type)]
4242

4343

44-
class RoundingError(Exception):
44+
class CastingError(Exception):
4545
pass
4646

4747

@@ -56,7 +56,7 @@ def nice_round(arr, int_type, nan2zero=True, infmax=False):
5656
Numpy integer type
5757
nan2zero : {True, False}
5858
Whether to convert NaN value to zero. Default is True. If False, and
59-
NaNs are present, raise RoundingError
59+
NaNs are present, raise CastingError
6060
infmax : {False, True}
6161
If True, set np.inf values in `arr` to be `int_type` integer maximum
6262
value, -np.inf as `int_type` integer minimum. If False, merely set infs
@@ -77,23 +77,21 @@ def nice_round(arr, int_type, nan2zero=True, infmax=False):
7777
arr = np.asarray(arr)
7878
flt_type = arr.dtype.type
7979
int_type = np.dtype(int_type).type
80+
# Deal with scalar as input; fancy indexing needs 1D
81+
shape = arr.shape
82+
arr = np.atleast_1d(arr)
8083
mn, mx = _cached_int_clippers(flt_type, int_type)
8184
nans = np.isnan(arr)
8285
have_nans = np.any(nans)
8386
if not nan2zero and have_nans:
84-
raise RoundingError('NaNs in array, nan2zero not True')
87+
raise CastingError('NaNs in array, nan2zero not True')
8588
iarr = np.clip(np.rint(arr), mn, mx).astype(int_type)
8689
if have_nans:
8790
iarr[nans] = 0
8891
if not infmax:
89-
return iarr
90-
# Deal with scalar as input
91-
shape = iarr.shape
92-
iarr = np.atleast_1d(iarr)
93-
arr = np.atleast_1d(arr)
92+
return iarr.reshape(shape)
9493
ii = np.iinfo(int_type)
9594
iarr[arr == np.inf] = ii.max
9695
if ii.min != int(mn):
9796
iarr[arr == -np.inf] = ii.min
98-
iarr.shape = shape
99-
return iarr
97+
return iarr.reshape(shape)

nibabel/tests/test_casting.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,9 @@
33

44
import numpy as np
55

6-
from ..casting import nice_round, int_clippers
6+
from ..casting import nice_round, int_clippers, CastingError
77

8-
from numpy.testing import (assert_array_almost_equal,
9-
assert_array_equal)
8+
from numpy.testing import (assert_array_almost_equal, assert_array_equal)
109

1110
from nose.tools import (assert_true, assert_equal, assert_raises)
1211

@@ -64,3 +63,7 @@ def test_casting():
6463
assert_array_equal(farr, np.array(arr, dtype=ft))
6564
# Test scalars work and return scalars
6665
assert_array_equal(nice_round(np.float32(0), np.int16), [0])
66+
# Test scalar nan OK
67+
assert_array_equal(nice_round(np.nan, np.int16), [0])
68+
# Test nans give error if not nan2zero
69+
assert_raises(CastingError, nice_round, np.nan, np.int16, False)

0 commit comments

Comments
 (0)