Skip to content

Commit 97ead00

Browse files
committed
added cmdline functionality for modifying datatype used for file data comparisons
1 parent cd85e09 commit 97ead00

File tree

1 file changed

+16
-8
lines changed

1 file changed

+16
-8
lines changed

nibabel/cmdline/diff.py

Lines changed: 16 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,10 @@ def get_opt_parser():
5454
" If --data-max-abs-diff is also specified, only the data points "
5555
" with absolute difference greater than that value would be "
5656
" considered for relative difference check."),
57+
Option("--dt", "--datatype",
58+
dest="dtype",
59+
default=np.float64,
60+
help="Enter a numpy datatype such as 'float32'.")
5761
])
5862

5963
return p
@@ -116,7 +120,7 @@ def get_headers_diff(file_headers, names=None):
116120
return difference
117121

118122

119-
def get_data_hash_diff(files):
123+
def get_data_hash_diff(files, dtype=np.float64):
120124
"""Get difference between md5 values of data
121125
122126
Parameters
@@ -130,7 +134,7 @@ def get_data_hash_diff(files):
130134
"""
131135

132136
md5sums = [
133-
hashlib.md5(np.ascontiguousarray(nib.load(f).get_fdata())).hexdigest()
137+
hashlib.md5(np.ascontiguousarray(nib.load(f).get_fdata(dtype=dtype))).hexdigest()
134138
for f in files
135139
]
136140

@@ -140,7 +144,7 @@ def get_data_hash_diff(files):
140144
return md5sums
141145

142146

143-
def get_data_diff(files, max_abs=0, max_rel=0):
147+
def get_data_diff(files, max_abs=0, max_rel=0, dtype=np.float64):
144148
"""Get difference between data
145149
146150
Parameters
@@ -153,6 +157,8 @@ def get_data_diff(files, max_abs=0, max_rel=0):
153157
Maximal relative (`abs(diff)/mean(diff)`) difference to tolerate.
154158
If `max_abs` is specified, then those data points with lesser than that
155159
absolute difference, are not considered for relative difference testing
160+
dtype: np, optional
161+
Datatype to be used when extracting data from files
156162
157163
Returns
158164
-------
@@ -167,7 +173,7 @@ def get_data_diff(files, max_abs=0, max_rel=0):
167173
"""
168174

169175
# we are doomed to keep them in RAM now
170-
data = [f if isinstance(f, np.ndarray) else nib.load(f).get_fdata()
176+
data = [f if isinstance(f, np.ndarray) else nib.load(f).get_fdata(dtype=dtype)
171177
for f in files]
172178
diffs = OrderedDict()
173179
for i, d1 in enumerate(data[:-1]):
@@ -268,7 +274,7 @@ def display_diff(files, diff):
268274
return output
269275

270276

271-
def diff(files, header_fields='all', data_max_abs_diff=None, data_max_rel_diff=None):
277+
def diff(files, header_fields='all', data_max_abs_diff=None, data_max_rel_diff=None, dtype=np.float64):
272278
assert len(files) >= 2, "Please enter at least two files"
273279

274280
file_headers = [nib.load(f).header for f in files]
@@ -282,13 +288,14 @@ def diff(files, header_fields='all', data_max_abs_diff=None, data_max_rel_diff=N
282288

283289
diff = get_headers_diff(file_headers, header_fields)
284290

285-
data_md5_diffs = get_data_hash_diff(files)
291+
data_md5_diffs = get_data_hash_diff(files, dtype)
286292
if data_md5_diffs:
287293
# provide details, possibly triggering the ignore of the difference
288294
# in data
289295
data_diffs = get_data_diff(files,
290296
max_abs=data_max_abs_diff,
291-
max_rel=data_max_rel_diff)
297+
max_rel=data_max_rel_diff,
298+
dtype=dtype)
292299
if data_diffs:
293300
diff['DATA(md5)'] = data_md5_diffs
294301
diff.update(data_diffs)
@@ -313,7 +320,8 @@ def main(args=None, out=None):
313320
files,
314321
header_fields=opts.header_fields,
315322
data_max_abs_diff=opts.data_max_abs_diff,
316-
data_max_rel_diff=opts.data_max_rel_diff
323+
data_max_rel_diff=opts.data_max_rel_diff,
324+
dtype=opts.dtype
317325
)
318326

319327
if files_diff:

0 commit comments

Comments
 (0)