Skip to content

Commit ae74339

Browse files
committed
i think im going to cry. code cleaned and made more pythonic
1 parent 7989563 commit ae74339

File tree

2 files changed

+56
-177
lines changed

2 files changed

+56
-177
lines changed

nibabel/cmdline/diff.py

Lines changed: 35 additions & 157 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
from optparse import OptionParser, Option
1919

2020
import numpy as np
21+
import functools.partial
2122

2223
import nibabel as nib
2324
import nibabel.cmdline.utils
@@ -43,177 +44,54 @@ def get_opt_parser():
4344
return p
4445

4546

46-
def diff_values(first_item, second_item):
47-
"""Generically compares two values, returns true if different"""
48-
if np.any(first_item != second_item): # comparing items that are instances of class np.ndarray
49-
return True
47+
def are_values_different(*values):
48+
"""Generically compares values, returns true if different"""
49+
value0 = values[0]
50+
values = values[1:] # to ensure that the first value isn't compared with itself
5051

51-
elif type(first_item) != type(second_item): # comparing items that differ in data type
52-
return True
52+
for value in values:
53+
try: # we don't want NaN values
54+
if np.any(np.isnan(value0)) or np.any(np.isnan(value)):
55+
break
5356

54-
else: # all other use cases
55-
return first_item != second_item
57+
except TypeError:
58+
pass
5659

60+
if type(value0) != type(value): # if types are different, then we consider them different
61+
return True
62+
elif isinstance(value0, np.ndarray) and np.any(value0 != value): # if they're a numpy array, special test
63+
return True
64+
elif value0 != value:
65+
return True
5766

58-
def diff_headers(files, fields):
59-
"""Iterates over all header fields of all files to find those that differ
67+
return False
6068

61-
Parameters
62-
----------
63-
files: a given list of files to be compared
64-
fields: the fields to be compared
6569

66-
Returns
67-
-------
68-
list
69-
header fields whose values differ across files
70-
"""
71-
72-
headers = []
73-
74-
for f in range(len(files)): # for each file
75-
for h in fields: # for each header
76-
77-
# each maneuver is encased in a try block after exceptions have previously occurred
78-
# get the particular header field within the particular file
79-
80-
try:
81-
field = files[f][h]
82-
83-
except ValueError:
84-
continue
85-
86-
# filter numpy arrays with a NaN value
87-
try:
88-
if np.all(np.isnan(field)):
89-
continue
90-
91-
except TypeError:
92-
pass
93-
94-
# compare current file with other files
95-
for i in files[f + 1:]:
96-
other_field = i[h]
97-
98-
# sometimes field.item doesn't work
99-
try:
100-
# converting bytes to be compared as strings
101-
if isinstance(field.item(0), bytes):
102-
field = field.item(0).decode("utf-8")
103-
104-
# converting np.ndarray to lists to remove ambiguity
105-
if isinstance(field, np.ndarray):
106-
field = field.tolist()
107-
108-
if isinstance(other_field.item(0), bytes):
109-
other_field = other_field.item(0).decode("utf-8")
110-
if isinstance(other_field, np.ndarray):
111-
other_field = other_field.tolist()
112-
113-
except AttributeError:
114-
continue
115-
116-
# if the header values of the two files are different, append
117-
if diff_values(field, other_field):
118-
headers.append(h)
119-
120-
if headers: # return a list of headers for the files whose values differ
121-
return headers
122-
123-
124-
def diff_header_fields(header_field, files):
125-
"""Iterates over a single header field of multiple files
126-
127-
Parameters
128-
----------
129-
header_field: a given header field
130-
files: the files to be compared
131-
132-
Returns
133-
-------
134-
list
135-
str for each value corresponding to each file's given header field
136-
"""
137-
138-
keyed_inputs = []
139-
140-
for i in files:
141-
142-
# each maneuver is encased in a try block after exceptions have previously occurred
143-
# get the particular header field within the particular file
144-
145-
try:
146-
field_value = i[header_field]
147-
except ValueError:
148-
continue
149-
150-
# compare different data types, return all values as soon as diff is found
151-
for x in files[1:]:
152-
try:
153-
data_diff = diff_values(str(x[header_field].dtype), str(field_value.dtype))
154-
155-
if data_diff:
156-
break
157-
except ValueError:
158-
continue
159-
160-
# string formatting of responses
161-
try:
162-
163-
# if differences are found among data types
164-
if data_diff:
165-
# accounting for how to arrange arrays
166-
if field_value.ndim < 1:
167-
keyed_inputs.append("{}@{}".format(field_value, field_value.dtype))
168-
elif field_value.ndim == 1:
169-
keyed_inputs.append("{}@{}".format(list(field_value), field_value.dtype))
170-
171-
# if no differences are found among data types
172-
else:
173-
if field_value.ndim < 1:
174-
keyed_inputs.append(field_value)
175-
elif field_value.ndim == 1:
176-
keyed_inputs.append(list(field_value))
177-
178-
except UnboundLocalError:
179-
continue
180-
181-
for i in range(len(keyed_inputs)):
182-
keyed_inputs[i] = str(keyed_inputs[i])
183-
184-
return keyed_inputs
185-
186-
187-
def get_headers_diff(file_headers, headers):
70+
def get_headers_diff(file_headers, names=None):
18871
"""Get difference between headers
18972
19073
Parameters
19174
----------
192-
file_headers: list of actual headers from files
193-
headers: list of header fields that differ
75+
file_headers: list of actual headers (dicts) from files
76+
names: list of header fields to test
19477
19578
Returns
19679
-------
19780
dict
19881
str: list for each header field which differs, return list of
19982
values per each file
20083
"""
201-
output = OrderedDict()
202-
203-
# if there are headers that differ
204-
if headers:
84+
difference = OrderedDict()
20585

206-
# for each header
207-
for header in headers:
86+
# for each header field
87+
for name in names:
88+
values = [header.get(name) for header in file_headers] # get corresponding value
20889

209-
# find the values corresponding to the files that differ
210-
val = diff_header_fields(header, file_headers)
90+
# if these values are different, store them in a dictionary
91+
if are_values_different(*values):
92+
difference[name] = values
21193

212-
# store these values in a dictionary
213-
if val:
214-
output[header] = val
215-
216-
return output
94+
return difference
21795

21896

21997
def get_data_md5sums(files):
@@ -252,16 +130,16 @@ def main():
252130
header_fields = file_headers[0].keys()
253131
else:
254132
header_fields = opts.header_fields.split(',')
255-
headers = diff_headers(file_headers, header_fields)
256-
diff = get_headers_diff(file_headers, headers)
133+
134+
diff = get_headers_diff(file_headers, header_fields)
257135
data_diff = get_data_md5sums(files)
258136

259137
if data_diff:
260138
diff['DATA(md5)'] = data_diff
261139

262140
if diff:
263141
print("These files are different.")
264-
print("{:<11}".format('Field'), end="")
142+
print("{:<15}".format('Field'), end="")
265143

266144
for f in files:
267145
output = ""
@@ -273,12 +151,12 @@ def main():
273151
output += f[i]
274152
i += 1
275153

276-
print("{:<45}".format(output), end="")
154+
print("{:<55}".format(output), end="")
277155

278156
print()
279157

280158
for key, value in diff.items():
281-
print("{:<11}".format(key), end="")
159+
print("{:<15}".format(key), end="")
282160

283161
for item in value:
284162
item_str = str(item)
@@ -289,7 +167,7 @@ def main():
289167
# and also replace some other invisible symbols with a question
290168
# mark
291169
item_str = re.sub('[\x00]', '?', item_str)
292-
print("{:<45}".format(item_str), end="")
170+
print("{:<55}".format(item_str), end="")
293171

294172
print()
295173

nibabel/tests/test_diff.py

Lines changed: 21 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -5,29 +5,29 @@
55
from __future__ import division, print_function, absolute_import
66

77
from os.path import (dirname, join as pjoin, abspath)
8+
import numpy as np
89

910
from hypothesis import given
1011
import hypothesis.strategies as st
1112

1213

1314
DATA_PATH = abspath(pjoin(dirname(__file__), 'data'))
1415

15-
from nibabel.cmdline.diff import diff_values
16+
from nibabel.cmdline.diff import are_values_different
1617

17-
# TODO: MAJOR TO DO IS TO FIGURE OUT HOW TO USE HYPOTHESIS FOR LONGER LIST
18-
# LENGTHS WHILE STILL CONTROLLING FOR OUTCOMES
18+
# TODO: MAJOR TO DO IS TO FIGURE OUT HOW TO USE HYPOTHESIS FOR LONGER LIST LENGTHS WHILE STILL CONTROLLING FOR OUTCOMES
1919

2020

2121
@given(st.data())
2222
def test_diff_values_int(data):
2323
x = data.draw(st.integers(), label='x')
24-
y = data.draw(st.integers(min_value = x + 1), label='x+1')
25-
z = data.draw(st.integers(max_value = x - 1), label='x-1')
24+
y = data.draw(st.integers(min_value=x + 1), label='x+1')
25+
z = data.draw(st.integers(max_value=x - 1), label='x-1')
2626

27-
assert not diff_values(x, x)
28-
assert diff_values(x, y)
29-
assert diff_values(x, z)
30-
assert diff_values(y, z)
27+
assert not are_values_different(x, x)
28+
assert are_values_different(x, y)
29+
assert are_values_different(x, z)
30+
assert are_values_different(y, z)
3131

3232

3333
@given(st.data())
@@ -36,10 +36,10 @@ def test_diff_values_float(data):
3636
y = data.draw(st.floats(min_value=1e8), label='y')
3737
z = data.draw(st.floats(max_value=-1e8), label='z')
3838

39-
assert not diff_values(x, x)
40-
assert diff_values(x, y)
41-
assert diff_values(x, z)
42-
assert diff_values(y, z)
39+
assert not are_values_different(x, x)
40+
assert are_values_different(x, y)
41+
assert are_values_different(x, z)
42+
assert are_values_different(y, z)
4343

4444

4545
@given(st.data())
@@ -48,10 +48,11 @@ def test_diff_values_mixed(data):
4848
type_int = data.draw(st.integers(), label='int')
4949
type_none = data.draw(st.none(), label='none')
5050

51-
assert diff_values(type_float, type_int)
52-
assert diff_values(type_float, type_none)
53-
assert diff_values(type_int, type_none)
54-
assert not diff_values(type_none, type_none)
51+
assert are_values_different(type_float, type_int)
52+
assert are_values_different(type_float, type_none)
53+
assert are_values_different(type_int, type_none)
54+
assert are_values_different(np.ndarray([0]), 'hey')
55+
assert not are_values_different(type_none, type_none)
5556

5657

5758
@given(st.data())
@@ -62,6 +63,6 @@ def test_diff_values_array(data):
6263
d = data.draw(st.lists(elements=st.floats(max_value=-1e8), min_size=1))
6364
# TODO: Figure out a way to include 0 in lists (arrays)
6465

65-
assert diff_values(a, b)
66-
assert diff_values(c, d)
67-
assert not diff_values(a, a)
66+
assert are_values_different(a, b)
67+
assert are_values_different(c, d)
68+
assert not are_values_different(a, a)

0 commit comments

Comments
 (0)