Skip to content

Commit 09cfef1

Browse files
committed
BF: properly compare NaNs and test for >2 arguments given to are_values_different
1 parent e8d3a4f commit 09cfef1

File tree

2 files changed

+51
-15
lines changed

2 files changed

+51
-15
lines changed

nibabel/cmdline/diff.py

Lines changed: 29 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -45,25 +45,41 @@ def get_opt_parser():
4545

4646

4747
def are_values_different(*values):
48-
"""Generically compares values, returns true if different"""
49-
value0 = values[0]
50-
values = values[1:] # to ensure that the first value isn't compared with itself
48+
"""Generically compare values, return True if different
5149
52-
for value in values:
53-
try: # we sometimes don't want NaN values
54-
if np.any(np.isnan(value0)) and np.any(np.isnan(value)): # if they're both NaN
55-
break
56-
elif np.any(np.isnan(value0)) or np.any(np.isnan(value)): # if only 1 is NaN
57-
return True
50+
Note that comparison is targetting reporting of comparison of the headers
51+
so has following specifics:
52+
- even a difference in data types is considered a difference, i.e. 1 != 1.0
53+
- NaNs are considered to be the "same", although generally NaN != NaN
54+
"""
55+
value0 = values[0]
5856

59-
except TypeError:
60-
pass
57+
# to not recompute over again
58+
if isinstance(value0, np.ndarray):
59+
value0_nans = np.isnan(value0)
60+
if not np.any(value0_nans):
61+
value0_nans = None
6162

63+
for value in values[1:]:
6264
if type(value0) != type(value): # if types are different, then we consider them different
6365
return True
6466
elif isinstance(value0, np.ndarray):
65-
return np.any(value0 != value)
66-
67+
if value0.dtype != value.dtype or \
68+
value0.shape != value.shape:
69+
return True
70+
# there might be NaNs and they need special treatment
71+
if value0_nans is not None:
72+
value_nans = np.isnan(value)
73+
if np.any(value0_nans != value_nans):
74+
return True
75+
if np.any(value0[np.logical_not(value0_nans)]
76+
!= value[np.logical_not(value0_nans)]):
77+
return True
78+
elif np.any(value0 != value):
79+
return True
80+
elif value0 is np.NaN:
81+
if value is not np.NaN:
82+
return True
6783
elif value0 != value:
6884
return True
6985

nibabel/tests/test_diff.py

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -44,10 +44,30 @@ def test_diff_values_mixed():
4444

4545

4646
def test_diff_values_array():
47-
a_int = np.array([1, 2])
47+
from numpy import NaN, array, inf
48+
a_int = array([1, 2])
4849
a_float = a_int.astype(float)
4950

50-
#assert are_values_different(a_int, a_float)
51+
assert are_values_different(a_int, a_float)
52+
assert are_values_different(a_int, a_int, a_float)
5153
assert are_values_different(np.arange(3), np.arange(1, 4))
54+
assert are_values_different(np.arange(3), np.arange(4))
55+
assert are_values_different(np.arange(4), np.arange(4).reshape((2, 2)))
56+
# no broadcasting should kick in - shape difference
57+
assert are_values_different(array([1]), array([1, 1]))
5258
assert not are_values_different(a_int, a_int)
5359
assert not are_values_different(a_float, a_float)
60+
61+
# NaNs - we consider them "the same" for the purpose of these comparisons
62+
assert not are_values_different(NaN, NaN)
63+
assert not are_values_different(NaN, NaN, NaN)
64+
assert are_values_different(NaN, NaN, 1)
65+
assert are_values_different(1, NaN, NaN)
66+
assert not are_values_different(array([NaN, NaN]), array([NaN, NaN]))
67+
assert not are_values_different(array([NaN, NaN]), array([NaN, NaN]), array([NaN, NaN]))
68+
assert not are_values_different(array([NaN, 1]), array([NaN, 1]))
69+
assert are_values_different(array([NaN, NaN]), array([NaN, 1]))
70+
assert are_values_different(array([0, NaN]), array([NaN, 0]))
71+
# and some inf should not be a problem
72+
assert not are_values_different(array([0, inf]), array([0, inf]))
73+
assert are_values_different(array([0, inf]), array([inf, 0]))

0 commit comments

Comments
 (0)