Skip to content

Commit ae75e20

Browse files
committed
NF - add able_int_type function
Function selects integer type capable of representing all passed values Numpy doesn't itself try and minimize the integer type so able_int_type has to do this the slow way. Confirm able_int_type corresponds to numpy casting.
1 parent 4709b63 commit ae75e20

File tree

2 files changed

+114
-1
lines changed

2 files changed

+114
-1
lines changed

nibabel/casting.py

Lines changed: 41 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,7 @@
1-
""" Utilties for casting floats to integers
1+
""" Utilties for casting numpy values in various ways
2+
3+
Most routines work round some numpy oddities in floating point precision and
4+
casting. Others work round numpy casting to and from python ints
25
"""
36

47
from platform import processor
@@ -483,3 +486,40 @@ def ok_floats():
483486

484487

485488
OK_FLOATS = ok_floats()
489+
490+
491+
def able_int_type(values):
492+
""" Find the smallest integer numpy type to contain sequence `values`
493+
494+
Prefers uint to int if minimum is >= 0
495+
496+
Parameters
497+
----------
498+
values : sequence
499+
sequence of integer values
500+
501+
Returns
502+
-------
503+
itype : None or numpy type
504+
numpy integer type or None if no integer type holds all `values`
505+
506+
Examples
507+
--------
508+
>>> able_int_type([0, 1]) == np.uint8
509+
True
510+
>>> able_int_type([-1, 1]) == np.int8
511+
True
512+
"""
513+
if any([v % 1 for v in values]):
514+
return None
515+
mn = min(values)
516+
mx = max(values)
517+
if mn >= 0:
518+
for ityp in np.sctypes['uint']:
519+
if mx <= np.iinfo(ityp).max:
520+
return ityp
521+
for ityp in np.sctypes['int']:
522+
info = np.iinfo(ityp)
523+
if mn >= info.min and mx <= info.max:
524+
return ityp
525+
return None

nibabel/tests/test_casting.py

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -136,3 +136,76 @@ def test_int_abs():
136136
assert_equal(int_abs(mx), mx)
137137
assert_equal(int_abs(mn), e_mn)
138138
assert_array_equal(int_abs(in_arr), [e_mn, mx])
139+
140+
141+
def test_able_int_type():
142+
# The integer type cabable of containing values
143+
for vals, exp_out in (
144+
([0, 1], np.uint8),
145+
([0, 255], np.uint8),
146+
([-1, 1], np.int8),
147+
([0, 256], np.uint16),
148+
([-1, 128], np.int16),
149+
([0.1, 1], None),
150+
([0, 2**16], np.uint32),
151+
([-1, 2**15], np.int32),
152+
([0, 2**32], np.uint64),
153+
([-1, 2**31], np.int64),
154+
([-1, 2**64-1], None),
155+
([0, 2**64-1], np.uint64),
156+
([0, 2**64], None)):
157+
assert_equal(able_int_type(vals), exp_out)
158+
159+
160+
def test_able_casting():
161+
# Check the able_int_type function guesses numpy out type
162+
types = np.sctypes['int'] + np.sctypes['uint']
163+
for in_type in types:
164+
in_info = np.iinfo(in_type)
165+
in_mn, in_mx = in_info.min, in_info.max
166+
A = np.zeros((1,), dtype=in_type)
167+
for out_type in types:
168+
out_info = np.iinfo(out_type)
169+
out_mn, out_mx = out_info.min, out_info.max
170+
B = np.zeros((1,), dtype=out_type)
171+
ApBt = (A + B).dtype.type
172+
able_type = able_int_type([in_mn, in_mx, out_mn, out_mx])
173+
if able_type is None:
174+
assert_equal(ApBt, np.float64)
175+
continue
176+
# Use str for comparison to avoid int32/64 vs intp comparison
177+
# failures
178+
assert_equal(np.dtype(ApBt).str, np.dtype(able_type).str)
179+
180+
181+
def test_best_float():
182+
# Finds the most capable floating point type
183+
# The only time this isn't np.longdouble is when np.longdouble has float64
184+
# precision.
185+
best = best_float()
186+
end_of_ints = np.float64(2**53)
187+
# float64 has continuous integers up to 2**53
188+
assert_equal(end_of_ints, end_of_ints + 1)
189+
# longdouble may have more, but not on 32 bit windows, at least
190+
end_of_ints = np.longdouble(2**53)
191+
if end_of_ints == (end_of_ints + 1): # off continuous integers
192+
assert_equal(best, np.float64)
193+
else:
194+
assert_equal(best, np.longdouble)
195+
196+
197+
def test_eps():
198+
assert_equal(eps(), np.finfo(np.float64).eps)
199+
assert_equal(eps(1.0), np.finfo(np.float64).eps)
200+
assert_equal(eps(np.float32(1.0)), np.finfo(np.float32).eps)
201+
assert_equal(eps(np.float32(1.999)), np.finfo(np.float32).eps)
202+
# Integers always return 1
203+
assert_equal(eps(1), 1)
204+
assert_equal(eps(2**63-1), 1)
205+
# negative / positive same
206+
assert_equal(eps(-1), 1)
207+
assert_equal(eps(7.999), eps(4.0))
208+
assert_equal(eps(-7.999), eps(4.0))
209+
assert_equal(eps(np.float64(2**54-2)), 2)
210+
assert_equal(eps(np.float64(2**54)), 4)
211+
assert_equal(eps(np.float64(2**54)), 4)

0 commit comments

Comments
 (0)