Skip to content

Commit 5b732b9

Browse files
committed
BF+TEST - fix and extend trackvis scalar, property checks on write
1 parent 870c036 commit 5b732b9

File tree

2 files changed

+63
-10
lines changed

2 files changed

+63
-10
lines changed

nibabel/tests/test_trackvis.py

Lines changed: 47 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -42,18 +42,62 @@ def test_write():
4242
tv.write, out_f, [],{'hdr_size': 0})
4343

4444

45+
def test_write_scalars_props():
46+
# Test writing of scalar array with streamlines
47+
N = 6
48+
M = 2
49+
P = 4
50+
points = np.arange(N*3).reshape((N,3))
51+
scalars = np.arange(N*M).reshape((N,M)) + 100
52+
props = np.arange(P) + 1000
53+
# If scalars not same size for each point, error
54+
out_f = StringIO()
55+
streams = [(points, None, None),
56+
(points, scalars, None)]
57+
assert_raises(tv.DataError, tv.write, out_f, streams)
58+
out_f.seek(0)
59+
streams = [(points, np.zeros((N,M+1)), None),
60+
(points, scalars, None)]
61+
assert_raises(tv.DataError, tv.write, out_f, streams)
62+
# Or if scalars different N compared to points
63+
bad_scalars = np.zeros((N+1,M))
64+
out_f.seek(0)
65+
streams = [(points, bad_scalars, None),
66+
(points, bad_scalars, None)]
67+
assert_raises(tv.DataError, tv.write, out_f, streams)
68+
# Similarly properties must have the same length for each streamline
69+
out_f.seek(0)
70+
streams = [(points, scalars, None),
71+
(points, scalars, props)]
72+
assert_raises(tv.DataError, tv.write, out_f, streams)
73+
out_f.seek(0)
74+
streams = [(points, scalars, np.zeros((P+1,))),
75+
(points, scalars, props)]
76+
assert_raises(tv.DataError, tv.write, out_f, streams)
77+
# If all is OK, then we get back what we put in
78+
out_f.seek(0)
79+
streams = [(points, scalars, props),
80+
(points, scalars, props)]
81+
tv.write(out_f, streams)
82+
out_f.seek(0)
83+
back_streams, hdr = tv.read(out_f)
84+
for actual, expected in zip(streams, back_streams):
85+
for a_el, e_el in zip(actual, expected):
86+
assert_array_equal(a_el, e_el)
87+
88+
4589
def streams_equal(stream1, stream2):
46-
if not np.all(stream1[0] == stream1[0]):
90+
if not np.all(stream1[0] == stream2[0]):
4791
return False
4892
if stream1[1] is None:
4993
if not stream2[1] is None:
5094
return False
5195
if stream1[2] is None:
5296
if not stream2[2] is None:
5397
return False
54-
if not np.all(stream1[1] == stream1[1]):
98+
if not np.all(stream1[1] == stream2[1]):
5599
return False
56-
if not np.all(stream1[2] == stream1[2]):
100+
if not np.all(stream1[2] == stream2[2]):
57101
return False
58102
return True
59103

nibabel/trackvis.py

Lines changed: 16 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,7 @@ def read(fileobj, as_generator=False):
108108
where M is the number of scalars per point
109109
#. properties : None or ndarray shape (P,)
110110
where P is the number of properties
111+
111112
hdr : structured array
112113
structured array with trackvis header fields
113114
@@ -218,7 +219,9 @@ def write(fileobj, streamlines, hdr_mapping=None, endianness=None):
218219
If `streamlines` has a ``len`` (for example, it is a list or a tuple),
219220
then we can write the number of streamlines into the header. Otherwise
220221
we write 0 for the number of streamlines (a valid trackvis header) and
221-
write streamlines into the file until the iterable is exhausted
222+
write streamlines into the file until the iterable is exhausted.
223+
M - the number of scalars - has to be the same for each streamline in
224+
`streamlines`. Similarly for P.
222225
hdr_mapping : None, ndarray or mapping, optional
223226
Information for filling header fields. Can be something
224227
dict-like (implementing ``items``) or a structured numpy array
@@ -255,13 +258,13 @@ def write(fileobj, streamlines, hdr_mapping=None, endianness=None):
255258
if not streams0 is None:
256259
pts, scalars, props = streams0
257260
# calculate number of scalars
258-
if scalars:
261+
if not scalars is None:
259262
n_s = scalars.shape[1]
260263
else:
261264
n_s = 0
262265
hdr['n_scalars'] = n_s
263266
# calculate number of properties
264-
if props:
267+
if not props is None:
265268
n_p = props.size
266269
hdr['n_properties'] = n_p
267270
else:
@@ -283,17 +286,23 @@ def write(fileobj, streamlines, hdr_mapping=None, endianness=None):
283286
# the endianness is OK.
284287
if pts.dtype != f4dt:
285288
pts = pts.astype(f4dt)
286-
if n_s:
289+
if n_s == 0:
290+
if not (scalars is None or len(scalars) == 0):
291+
raise DataError('Expecting 0 scalars per point')
292+
else:
287293
if scalars.shape != (n_pts, n_s):
288-
raise ValueError('Scalars should be shape (%s, %s)'
294+
raise DataError('Scalars should be shape (%s, %s)'
289295
% (n_pts, n_s))
290296
if scalars.dtype != f4dt:
291297
scalars = scalars.astype(f4dt)
292298
pts = np.c_[pts, scalars]
293299
fileobj.write(pts.tostring())
294-
if n_p:
300+
if n_p == 0:
301+
if not (props is None or len(props) == 0):
302+
raise DataError('Expecting 0 properties per point')
303+
else:
295304
if props.size != n_p:
296-
raise ValueError('Properties should be size %s' % n_p)
305+
raise DataError('Properties should be size %s' % n_p)
297306
if props.dtype != f4dt:
298307
props = props.astype(f4dt)
299308
fileobj.write(props.tostring())

0 commit comments

Comments
 (0)