Skip to content

Commit 0fdcecf

Browse files
committed
BF+TEST - fix and extend trackvis scalar, property checks on write
1 parent b37565c commit 0fdcecf

File tree

2 files changed

+63
-10
lines changed

2 files changed

+63
-10
lines changed

nibabel/tests/test_trackvis.py

Lines changed: 47 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -42,18 +42,62 @@ def test_write():
4242
tv.write, out_f, [],{'hdr_size': 0})
4343

4444

45+
def test_write_scalars_props():
46+
# Test writing of scalar array with streamlines
47+
N = 6
48+
M = 2
49+
P = 4
50+
points = np.arange(N*3).reshape((N,3))
51+
scalars = np.arange(N*M).reshape((N,M)) + 100
52+
props = np.arange(P) + 1000
53+
# If scalars not same size for each point, error
54+
out_f = StringIO()
55+
streams = [(points, None, None),
56+
(points, scalars, None)]
57+
assert_raises(tv.DataError, tv.write, out_f, streams)
58+
out_f.seek(0)
59+
streams = [(points, np.zeros((N,M+1)), None),
60+
(points, scalars, None)]
61+
assert_raises(tv.DataError, tv.write, out_f, streams)
62+
# Or if scalars different N compared to points
63+
bad_scalars = np.zeros((N+1,M))
64+
out_f.seek(0)
65+
streams = [(points, bad_scalars, None),
66+
(points, bad_scalars, None)]
67+
assert_raises(tv.DataError, tv.write, out_f, streams)
68+
# Similarly properties must have the same length for each streamline
69+
out_f.seek(0)
70+
streams = [(points, scalars, None),
71+
(points, scalars, props)]
72+
assert_raises(tv.DataError, tv.write, out_f, streams)
73+
out_f.seek(0)
74+
streams = [(points, scalars, np.zeros((P+1,))),
75+
(points, scalars, props)]
76+
assert_raises(tv.DataError, tv.write, out_f, streams)
77+
# If all is OK, then we get back what we put in
78+
out_f.seek(0)
79+
streams = [(points, scalars, props),
80+
(points, scalars, props)]
81+
tv.write(out_f, streams)
82+
out_f.seek(0)
83+
back_streams, hdr = tv.read(out_f)
84+
for actual, expected in zip(streams, back_streams):
85+
for a_el, e_el in zip(actual, expected):
86+
assert_array_equal(a_el, e_el)
87+
88+
4589
def streams_equal(stream1, stream2):
46-
if not np.all(stream1[0] == stream1[0]):
90+
if not np.all(stream1[0] == stream2[0]):
4791
return False
4892
if stream1[1] is None:
4993
if not stream2[1] is None:
5094
return False
5195
if stream1[2] is None:
5296
if not stream2[2] is None:
5397
return False
54-
if not np.all(stream1[1] == stream1[1]):
98+
if not np.all(stream1[1] == stream2[1]):
5599
return False
56-
if not np.all(stream1[2] == stream1[2]):
100+
if not np.all(stream1[2] == stream2[2]):
57101
return False
58102
return True
59103

nibabel/trackvis.py

Lines changed: 16 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,7 @@ def read(fileobj, as_generator=False):
108108
where M is the number of scalars per point
109109
#. properties : None or ndarray shape (P,)
110110
where P is the number of properties
111+
111112
hdr : structured array
112113
structured array with trackvis header fields
113114
@@ -218,7 +219,9 @@ def write(fileobj, streamlines, hdr_mapping=None, endianness=None):
218219
If `streamlines` has a ``len`` (for example, it is a list or a tuple),
219220
then we can write the number of streamlines into the header. Otherwise
220221
we write 0 for the number of streamlines (a valid trackvis header) and
221-
write streamlines into the file until the iterable is exhausted
222+
write streamlines into the file until the iterable is exhausted.
223+
M - the number of scalars - has to be the same for each streamline in
224+
`streamlines`. Similarly for P.
222225
hdr_mapping : None, ndarray or mapping, optional
223226
Information for filling header fields. Can be something
224227
dict-like (implementing ``items``) or a structured numpy array
@@ -281,13 +284,13 @@ def write(fileobj, streamlines, hdr_mapping=None, endianness=None):
281284
if not streams0 is None:
282285
pts, scalars, props = streams0
283286
# calculate number of scalars
284-
if scalars:
287+
if not scalars is None:
285288
n_s = scalars.shape[1]
286289
else:
287290
n_s = 0
288291
hdr['n_scalars'] = n_s
289292
# calculate number of properties
290-
if props:
293+
if not props is None:
291294
n_p = props.size
292295
hdr['n_properties'] = n_p
293296
else:
@@ -309,17 +312,23 @@ def write(fileobj, streamlines, hdr_mapping=None, endianness=None):
309312
# the endianness is OK.
310313
if pts.dtype != f4dt:
311314
pts = pts.astype(f4dt)
312-
if n_s:
315+
if n_s == 0:
316+
if not (scalars is None or len(scalars) == 0):
317+
raise DataError('Expecting 0 scalars per point')
318+
else:
313319
if scalars.shape != (n_pts, n_s):
314-
raise ValueError('Scalars should be shape (%s, %s)'
320+
raise DataError('Scalars should be shape (%s, %s)'
315321
% (n_pts, n_s))
316322
if scalars.dtype != f4dt:
317323
scalars = scalars.astype(f4dt)
318324
pts = np.c_[pts, scalars]
319325
fileobj.write(pts.tostring())
320-
if n_p:
326+
if n_p == 0:
327+
if not (props is None or len(props) == 0):
328+
raise DataError('Expecting 0 properties per point')
329+
else:
321330
if props.size != n_p:
322-
raise ValueError('Properties should be size %s' % n_p)
331+
raise DataError('Properties should be size %s' % n_p)
323332
if props.dtype != f4dt:
324333
props = props.astype(f4dt)
325334
fileobj.write(props.tostring())

0 commit comments

Comments
 (0)