Skip to content

Commit d21a980

Browse files
committed
ENH: speed up TRK loading by setting a suitable buffer size.
1 parent b37af09 commit d21a980

File tree

2 files changed

+23
-3
lines changed

2 files changed

+23
-3
lines changed

nibabel/streamlines/array_sequence.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -547,7 +547,7 @@ def load(cls, filename):
547547
return seq
548548

549549

550-
def create_arraysequences_from_generator(gen, n):
550+
def create_arraysequences_from_generator(gen, n, buffer_sizes=None):
551551
""" Creates :class:`ArraySequence` objects from a generator yielding tuples
552552
553553
Parameters
@@ -557,8 +557,13 @@ def create_arraysequences_from_generator(gen, n):
557557
array sequences.
558558
n : int
559559
Number of :class:`ArraySequence` objects to create.
560+
buffer_sizes : list of float, optional
561+
Sizes (in MB) for each ArraySequence's buffer.
560562
"""
561-
seqs = [ArraySequence() for _ in range(n)]
563+
if buffer_sizes is None:
564+
buffer_sizes = [4] * n
565+
566+
seqs = [ArraySequence(buffer_size=size) for size in buffer_sizes]
562567
for data in gen:
563568
for i, seq in enumerate(seqs):
564569
if data[i].nbytes > 0:

nibabel/streamlines/trk.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -372,8 +372,23 @@ def _read():
372372
tractogram = LazyTractogram.from_data_func(_read)
373373

374374
else:
375+
376+
# Speed up loading by guessing a suitable buffer size.
377+
with Opener(fileobj) as f:
378+
old_file_position = f.tell()
379+
f.seek(0, os.SEEK_END)
380+
size = f.tell()
381+
f.seek(old_file_position, os.SEEK_SET)
382+
383+
# Buffer size is in megabytes.
384+
mbytes = size // (1024 * 1024)
385+
sizes = [mbytes, 4, 4]
386+
if hdr["nb_scalars_per_point"] > 0:
387+
sizes = [mbytes // 2, mbytes // 2, 4]
388+
375389
trk_reader = cls._read(fileobj, hdr)
376-
arr_seqs = create_arraysequences_from_generator(trk_reader, n=3)
390+
arr_seqs = create_arraysequences_from_generator(trk_reader, n=3,
391+
buffer_sizes=sizes)
377392
streamlines, scalars, properties = arr_seqs
378393
properties = np.asarray(properties) # Actually a 2d array.
379394
tractogram = Tractogram(streamlines)

0 commit comments

Comments
 (0)