@@ -54,6 +54,10 @@ def get_opt_parser():
54
54
" If --data-max-abs-diff is also specified, only the data points "
55
55
" with absolute difference greater than that value would be "
56
56
" considered for relative difference check." ),
57
+ Option ("--dt" , "--datatype" ,
58
+ dest = "dtype" ,
59
+ default = np .float64 ,
60
+ help = "Enter a numpy datatype such as 'float32'." )
57
61
])
58
62
59
63
return p
@@ -116,7 +120,7 @@ def get_headers_diff(file_headers, names=None):
116
120
return difference
117
121
118
122
119
- def get_data_hash_diff (files ):
123
+ def get_data_hash_diff (files , dtype = np . float64 ):
120
124
"""Get difference between md5 values of data
121
125
122
126
Parameters
@@ -130,7 +134,7 @@ def get_data_hash_diff(files):
130
134
"""
131
135
132
136
md5sums = [
133
- hashlib .md5 (np .ascontiguousarray (nib .load (f ).get_fdata ())).hexdigest ()
137
+ hashlib .md5 (np .ascontiguousarray (nib .load (f ).get_fdata (dtype = dtype ))).hexdigest ()
134
138
for f in files
135
139
]
136
140
@@ -140,7 +144,7 @@ def get_data_hash_diff(files):
140
144
return md5sums
141
145
142
146
143
- def get_data_diff (files , max_abs = 0 , max_rel = 0 ):
147
+ def get_data_diff (files , max_abs = 0 , max_rel = 0 , dtype = np . float64 ):
144
148
"""Get difference between data
145
149
146
150
Parameters
@@ -153,6 +157,8 @@ def get_data_diff(files, max_abs=0, max_rel=0):
153
157
Maximal relative (`abs(diff)/mean(diff)`) difference to tolerate.
154
158
If `max_abs` is specified, then those data points with lesser than that
155
159
absolute difference, are not considered for relative difference testing
160
+ dtype: np, optional
161
+ Datatype to be used when extracting data from files
156
162
157
163
Returns
158
164
-------
@@ -167,7 +173,7 @@ def get_data_diff(files, max_abs=0, max_rel=0):
167
173
"""
168
174
169
175
# we are doomed to keep them in RAM now
170
- data = [f if isinstance (f , np .ndarray ) else nib .load (f ).get_fdata ()
176
+ data = [f if isinstance (f , np .ndarray ) else nib .load (f ).get_fdata (dtype = dtype )
171
177
for f in files ]
172
178
diffs = OrderedDict ()
173
179
for i , d1 in enumerate (data [:- 1 ]):
@@ -268,7 +274,7 @@ def display_diff(files, diff):
268
274
return output
269
275
270
276
271
- def diff (files , header_fields = 'all' , data_max_abs_diff = None , data_max_rel_diff = None ):
277
+ def diff (files , header_fields = 'all' , data_max_abs_diff = None , data_max_rel_diff = None , dtype = np . float64 ):
272
278
assert len (files ) >= 2 , "Please enter at least two files"
273
279
274
280
file_headers = [nib .load (f ).header for f in files ]
@@ -282,13 +288,14 @@ def diff(files, header_fields='all', data_max_abs_diff=None, data_max_rel_diff=N
282
288
283
289
diff = get_headers_diff (file_headers , header_fields )
284
290
285
- data_md5_diffs = get_data_hash_diff (files )
291
+ data_md5_diffs = get_data_hash_diff (files , dtype )
286
292
if data_md5_diffs :
287
293
# provide details, possibly triggering the ignore of the difference
288
294
# in data
289
295
data_diffs = get_data_diff (files ,
290
296
max_abs = data_max_abs_diff ,
291
- max_rel = data_max_rel_diff )
297
+ max_rel = data_max_rel_diff ,
298
+ dtype = dtype )
292
299
if data_diffs :
293
300
diff ['DATA(md5)' ] = data_md5_diffs
294
301
diff .update (data_diffs )
@@ -313,7 +320,8 @@ def main(args=None, out=None):
313
320
files ,
314
321
header_fields = opts .header_fields ,
315
322
data_max_abs_diff = opts .data_max_abs_diff ,
316
- data_max_rel_diff = opts .data_max_rel_diff
323
+ data_max_rel_diff = opts .data_max_rel_diff ,
324
+ dtype = opts .dtype
317
325
)
318
326
319
327
if files_diff :
0 commit comments