Skip to content

Commit 859143b

Browse files
committed
Fixes #586
1 parent b8545ef commit 859143b

File tree

3 files changed

+43
-11
lines changed

3 files changed

+43
-11
lines changed

nibabel/streamlines/tck.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77

88
import os
99
import warnings
10+
import itertools
1011

1112
import numpy as np
1213

@@ -191,8 +192,18 @@ def save(self, fileobj):
191192
# Write temporary header that we will update at the end
192193
self._write_header(f, header)
193194

195+
# Make sure streamlines are in rasmm.
196+
tractogram = self.tractogram.to_world(lazy=True)
197+
# Assume looping over the streamlines can be done only once.
198+
tractogram = iter(tractogram)
199+
194200
try:
195-
first_item = next(iter(self.tractogram))
201+
# Use the first element to check
202+
# 1) the tractogram is not empty;
203+
# 2) quantity of information saved along each streamline.
204+
first_item = next(tractogram)
205+
# Put back the first element at its place.
206+
tractogram = itertools.chain([first_item], tractogram)
196207
except StopIteration:
197208
# Empty tractogram
198209
header[Field.NB_STREAMLINES] = 0
@@ -216,9 +227,6 @@ def save(self, fileobj):
216227
" alongside points. Dropping: {}".format(keys))
217228
warnings.warn(msg, DataWarning)
218229

219-
# Make sure streamlines are in rasmm.
220-
tractogram = self.tractogram.to_world(lazy=True)
221-
222230
for t in tractogram:
223231
data = np.r_[t.streamline, self.FIBER_DELIMITER]
224232
f.write(data.astype(dtype).tostring())

nibabel/streamlines/tests/test_streamlines.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -272,3 +272,19 @@ def test_load_unknown_format(self):
272272

273273
def test_save_unknown_format(self):
274274
assert_raises(ValueError, nib.streamlines.save, Tractogram(), "")
275+
276+
def test_save_from_generator(self):
277+
tractogram = Tractogram(DATA['streamlines'],
278+
affine_to_rasmm=np.eye(4))
279+
280+
# Just to create a generator
281+
for ext, _ in FORMATS.items():
282+
filtered = (s for s in tractogram.streamlines if True)
283+
lazy_tractogram = LazyTractogram(lambda: filtered,
284+
affine_to_rasmm=np.eye(4))
285+
286+
with InTemporaryDirectory():
287+
filename = 'streamlines' + ext
288+
nib.streamlines.save(lazy_tractogram, filename)
289+
tfile = nib.streamlines.load(filename, lazy_load=False)
290+
assert_tractogram_equal(tfile.tractogram, tractogram)

nibabel/streamlines/trk.py

Lines changed: 15 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,9 @@
55

66
import os
77
import struct
8-
import warnings
98
import string
9+
import warnings
10+
import itertools
1011

1112
import numpy as np
1213
import nibabel as nib
@@ -423,8 +424,20 @@ def save(self, fileobj):
423424
i4_dtype = np.dtype("<i4") # Always save in little-endian.
424425
f4_dtype = np.dtype("<f4") # Always save in little-endian.
425426

427+
# Make sure streamlines are in rasmm then send them to voxmm.
428+
tractogram = self.tractogram.to_world(lazy=True)
429+
affine_to_trackvis = get_affine_rasmm_to_trackvis(header)
430+
tractogram = tractogram.apply_affine(affine_to_trackvis, lazy=True)
431+
# Assume looping over the streamlines can be done only once.
432+
tractogram = iter(tractogram)
433+
426434
try:
427-
first_item = next(iter(self.tractogram))
435+
# Use the first element to check
436+
# 1) the tractogram is not empty;
437+
# 2) quantity of information saved along each streamline.
438+
first_item = next(tractogram)
439+
# Put back the first element at its place.
440+
tractogram = itertools.chain([first_item], tractogram)
428441
except StopIteration:
429442
# Empty tractogram
430443
header[Field.NB_STREAMLINES] = 0
@@ -470,11 +483,6 @@ def save(self, fileobj):
470483
scalar_name[i] = encode_value_in_name(nb_values, name)
471484
header['scalar_name'][:] = scalar_name
472485

473-
# Make sure streamlines are in rasmm then send them to voxmm.
474-
tractogram = self.tractogram.to_world(lazy=True)
475-
affine_to_trackvis = get_affine_rasmm_to_trackvis(header)
476-
tractogram = tractogram.apply_affine(affine_to_trackvis, lazy=True)
477-
478486
for t in tractogram:
479487
if any((len(d) != len(t.streamline)
480488
for d in t.data_for_points.values())):

0 commit comments

Comments
 (0)