Commit be485c6: Improved speed for loading TRK
1 parent 17bd66c
5 files changed: +206 additions, -39 deletions

Changelog

Lines changed: 6 additions & 3 deletions
@@ -36,6 +36,9 @@ References like "pr/298" refer to github pull request numbers.
   are raising a DataError if the track is truncated when ``strict=True``
   (the default), rather than a TypeError when trying to create the points
   array.
+* New API for managing streamlines and their different file formats. This
+  adds a new module ``nibabel.streamlines`` that will eventually deprecate
+  the current trackvis reader found in ``nibabel.trackvis``.

 * 2.0.2 (Monday 23 November 2015)

@@ -251,7 +254,7 @@ References like "pr/298" refer to github pull request numbers.
   the ability to transform to the image with data closest to the canonical
   image orientation (first axis left-to-right, second back-to-front, third
   down-to-up) (MB, Jonathan Taylor)
-* Gifti format read and write support (preliminary) (Stephen Gerhard)
+* Gifti format read and write support (preliminary) (Stephen Gerhard)
 * Added utilities to use nipy-style data packages, by rip then edit of nipy
   data package code (MB)
 * Some improvements to release support (Jarrod Millman, MB, Fernando Perez)

@@ -469,7 +472,7 @@ visiting the URL::

 * Removed functionality for "NiftiImage.save() raises an IOError
   exception when writing the image file fails." (Yaroslav Halchenko)
-* Added ability to force a filetype when setting the filename or saving
+* Added ability to force a filetype when setting the filename or saving
   a file.
 * Reverse the order of the 'header' and 'load' argument in the NiftiImage
   constructor. 'header' is now first as it seems to be used more often.

@@ -481,7 +484,7 @@ visiting the URL::

 * 0.20070301.2 (Thu, 1 Mar 2007)

-* Fixed wrong link to the source tarball in README.html.
+* Fixed wrong link to the source tarball in README.html.

 * 0.20070301.1 (Thu, 1 Mar 2007)
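
As a quick illustration of the new ``nibabel.streamlines`` API announced in the
entry above, a minimal round-trip looks like the following sketch. It only uses
calls that appear elsewhere in this commit (``Tractogram``, ``TrkFile``,
``nibabel.streamlines.load``); the file name is arbitrary::

    import numpy as np
    import nibabel as nib
    from nibabel.streamlines import Tractogram, TrkFile

    # Two short streamlines; the identity affine maps the points
    # unchanged to RAS+ mm space.
    streamlines = [np.random.rand(10, 3).astype('float32'),
                   np.random.rand(20, 3).astype('float32')]
    tractogram = Tractogram(streamlines, affine_to_rasmm=np.eye(4))

    # Save with the format-specific class, reload with the generic loader.
    TrkFile(tractogram).save('example.trk')
    trk = nib.streamlines.load('example.trk', lazy_load=False)
    print(len(trk.streamlines))  # -> 2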
bench_streamlines.py

Lines changed: 80 additions & 0 deletions (new file)

""" Benchmarks for load and save of streamlines

Run benchmarks with::

    import nibabel as nib
    nib.bench()

If you have doctests enabled by default in nose (with a noserc file or
environment variable), and you have a numpy version <= 1.6.1, this will also
run the doctests; let's hope they pass.

Run this benchmark with::

    nosetests -s --match '(?:^|[\\b_\\.//-])[Bb]ench' /path/to/bench_streamlines.py
"""
from __future__ import division, print_function

import numpy as np

import nibabel as nib
import nibabel.trackvis as tv

from nibabel.externals.six.moves import zip
from nibabel.tmpdirs import InTemporaryDirectory
from nibabel.streamlines import Tractogram
from nibabel.streamlines import TrkFile

from numpy.testing import assert_array_equal, measure


def bench_load_trk():
    rng = np.random.RandomState(42)
    dtype = 'float32'
    NB_STREAMLINES = 5000
    NB_POINTS = 1000
    points = [rng.rand(NB_POINTS, 3).astype(dtype)
              for i in range(NB_STREAMLINES)]
    scalars = [rng.rand(NB_POINTS, 10).astype(dtype)
               for i in range(NB_STREAMLINES)]

    repeat = 10

    # Points only.
    with InTemporaryDirectory():
        trk_file = "tmp.trk"
        tractogram = Tractogram(points, affine_to_rasmm=np.eye(4))
        TrkFile(tractogram).save(trk_file)

        loaded_streamlines_old = [d[0] - 0.5
                                  for d in tv.read(trk_file,
                                                   points_space="rasmm")[0]]
        mtime_old = measure('tv.read(trk_file, points_space="rasmm")', repeat)
        print("Old: Loaded %d streamlines in %6.2f" % (NB_STREAMLINES,
                                                       mtime_old))

        loaded_streamlines_new = nib.streamlines.load(
            trk_file, lazy_load=False).streamlines
        mtime_new = measure('nib.streamlines.load(trk_file, lazy_load=False)',
                            repeat)
        print("\nNew: Loaded %d streamlines in %6.2f" % (NB_STREAMLINES,
                                                         mtime_new))
        print("Speedup of %.2f" % (mtime_old / mtime_new))

        for s1, s2 in zip(loaded_streamlines_new, loaded_streamlines_old):
            assert_array_equal(s1, s2)

    # Points and scalars.
    with InTemporaryDirectory():
        trk_file = "tmp.trk"
        tractogram = Tractogram(points,
                                data_per_point={'scalars': scalars},
                                affine_to_rasmm=np.eye(4))
        TrkFile(tractogram).save(trk_file)

        mtime_old = measure('tv.read(trk_file, points_space="rasmm")', repeat)
        print("Old: Loaded %d streamlines with scalars in %6.2f"
              % (NB_STREAMLINES, mtime_old))

        mtime_new = measure('nib.streamlines.load(trk_file, lazy_load=False)',
                            repeat)
        print("New: Loaded %d streamlines with scalars in %6.2f"
              % (NB_STREAMLINES, mtime_new))
        print("Speedup of %.2f" % (mtime_old / mtime_new))


if __name__ == '__main__':
    bench_load_trk()
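
Both timings above go through ``numpy.testing.measure``, which compiles the
code string and executes it ``times`` times in the caller's frame, returning
the elapsed CPU time in seconds; that is why ``trk_file`` is visible inside
the timed string. A minimal standalone sketch (the timed expression is
arbitrary)::

    from numpy.testing import measure

    n = 100000
    # The string is evaluated with the caller's locals, so `n` is visible.
    elapsed = measure('sum(range(n))', times=10)
    print('10 runs took %.4f s' % elapsed)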

nibabel/streamlines/array_sequence.py

Lines changed: 88 additions & 35 deletions
@@ -60,40 +60,83 @@ def __init__(self, iterable=None, buffer_size=4):
             self._is_view = True
             return

-        # Add elements of the iterable.
+        try:
+            # If possible try pre-allocating memory.
+            if len(iterable) > 0:
+                first_element = np.asarray(iterable[0])
+                n_elements = np.sum([len(iterable[i])
+                                     for i in range(len(iterable))])
+                new_shape = (n_elements,) + first_element.shape[1:]
+                self._data = np.empty(new_shape, dtype=first_element.dtype)
+        except TypeError:
+            pass
+
+        # Initialize the `ArraySequence` object from the iterable's items.
+        coroutine = self._extend_using_coroutine()
+        coroutine.send(None)  # Run until the first yield.
+
+        for e in iterable:
+            coroutine.send(e)
+
+        coroutine.close()  # Terminate coroutine.
+
+    def _extend_using_coroutine(self, buffer_size=4):
+        """ Creates a coroutine allowing to append elements.
+
+        Parameters
+        ----------
+        buffer_size : float, optional
+            Size (in Mb) for memory pre-allocation.
+
+        Returns
+        -------
+        coroutine
+            Coroutine object which expects the values to be appended to this
+            array sequence.
+
+        Notes
+        -----
+        This method is essential for
+        :func:`create_arraysequences_from_generator` as it allows for an
+        efficient way of creating multiple array sequences in an interleaved
+        fashion and still benefit from the memory buffering. Without this
+        method the alternative would be to use :meth:`append`, which does not
+        have such a buffering mechanism and thus is at least one order of
+        magnitude slower.
+        """
         offsets = []
         lengths = []
-        # Initialize the `ArraySequence` object from iterable's item.
-        offset = 0
-        for i, e in enumerate(iterable):
-            e = np.asarray(e)
-            if i == 0:
-                try:
-                    n_elements = np.sum([len(iterable[i])
-                                         for i in range(len(iterable))])
-                    new_shape = (n_elements,) + e.shape[1:]
-                except TypeError:
-                    # Can't get the number of elements in iterable. So,
-                    # we use a memory buffer while building the ArraySequence.
+
+        offset = 0 if len(self) == 0 else self._offsets[-1] + self._lengths[-1]
+        try:
+            first_element = True
+            while True:
+                e = (yield)
+                e = np.asarray(e)
+                if first_element:
+                    first_element = False
                     n_rows_buffer = int(buffer_size * 1024**2 // e.nbytes)
                     new_shape = (n_rows_buffer,) + e.shape[1:]
+                    if len(self) == 0:
+                        self._data = np.empty(new_shape, dtype=e.dtype)

-                self._data = np.empty(new_shape, dtype=e.dtype)
+                end = offset + len(e)
+                if end > len(self._data):
+                    # Resize needed, adding `len(e)` items plus some buffer.
+                    nb_points = len(self._data)
+                    nb_points += len(e) + n_rows_buffer
+                    self._data.resize((nb_points,) + self.common_shape)

-                end = offset + len(e)
-                if end > len(self._data):
-                    # Resize needed, adding `len(e)` items plus some buffer.
-                    nb_points = len(self._data)
-                    nb_points += len(e) + n_rows_buffer
-                    self._data.resize((nb_points,) + self.common_shape)
+                offsets.append(offset)
+                lengths.append(len(e))
+                self._data[offset:offset + len(e)] = e
+                offset += len(e)

-                offsets.append(offset)
-                lengths.append(len(e))
-                self._data[offset:offset + len(e)] = e
-                offset += len(e)
+        except GeneratorExit:
+            pass

-        self._offsets = np.asarray(offsets)
-        self._lengths = np.asarray(lengths)
+        self._offsets = np.concatenate([self._offsets, offsets], axis=0)
+        self._lengths = np.concatenate([self._lengths, lengths], axis=0)

         # Clear unused memory.
         self._data.resize((offset,) + self.common_shape)
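
The buffering idea in ``_extend_using_coroutine`` stands on its own: a plain
generator used as a coroutine keeps its pre-allocated buffer alive between
``send()`` calls, so many short arrays can be packed into one large array with
only occasional resizes. The following stripped-down sketch (hypothetical
``buffered_appender``, simplified growth policy, results returned through a
dict instead of mutating ``self``) shows the pattern::

    import numpy as np

    def buffered_appender(out, buffer_size_mb=4):
        """Pack 2-D arrays received via send() into one growing buffer."""
        data, offset, step = None, 0, 0
        try:
            while True:
                e = np.asarray((yield))
                if data is None:
                    # The first element fixes dtype, row width and the
                    # growth step (roughly buffer_size_mb worth of rows).
                    step = max(1, int(buffer_size_mb * 1024**2 // e.nbytes))
                    data = np.empty((step,) + e.shape[1:], dtype=e.dtype)
                if offset + len(e) > len(data):
                    # Grow by the incoming length plus one buffer's worth.
                    data = np.resize(
                        data, (len(data) + len(e) + step,) + e.shape[1:])
                data[offset:offset + len(e)] = e
                offset += len(e)
        except GeneratorExit:
            pass  # close() lands here; fall through to trim the buffer.
        out['data'] = data[:offset] if data is not None else np.empty(0)

    # Usage: prime with send(None), feed arrays, close() to finalize.
    out = {}
    appender = buffered_appender(out)
    appender.send(None)
    for chunk in (np.ones((5, 3)), np.zeros((2, 3))):
        appender.send(chunk)
    appender.close()
    assert out['data'].shape == (7, 3)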
@@ -266,13 +309,6 @@ def __getitem__(self, idx):
             seq._is_view = True
             return seq

-        # for name, slice_ in data_per_point_slice.items():
-        #     seq = ArraySequence()
-        #     seq._data = scalars._data[:, slice_]
-        #     seq._offsets = scalars._offsets
-        #     seq._lengths = scalars._lengths
-        #     tractogram.data_per_point[name] = seq
-
         raise TypeError("Index must be either an int, a slice, a list of int"
                         " or a ndarray of bool! Not " + str(type(idx)))

@@ -320,10 +356,27 @@ def load(cls, filename):

 def create_arraysequences_from_generator(gen, n):
     """ Creates :class:`ArraySequence` objects from a generator yielding tuples
+
+    Parameters
+    ----------
+    gen : generator
+        Generator yielding a size `n` tuple containing the values to put in
+        the array sequences.
+    n : int
+        Number of :class:`ArraySequence` objects to create.
     """
     seqs = [ArraySequence() for _ in range(n)]
+    coroutines = [seq._extend_using_coroutine() for seq in seqs]
+
+    for coroutine in coroutines:
+        coroutine.send(None)
+
     for data in gen:
-        for i, seq in enumerate(seqs):
-            seq.append(data[i])
+        for i, coroutine in enumerate(coroutines):
+            if data[i].nbytes > 0:
+                coroutine.send(data[i])
+
+    for coroutine in coroutines:
+        coroutine.close()

     return seqs
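
To make the fan-out concrete, here is a hedged usage sketch of
``create_arraysequences_from_generator`` as it exists at this commit: a
generator yields ``(points, scalars)`` tuples and one ``ArraySequence`` is
built per tuple position, each fed through its own buffering coroutine (the
sample data is made up)::

    import numpy as np
    from nibabel.streamlines.array_sequence import (
        create_arraysequences_from_generator)

    rng = np.random.RandomState(0)

    def gen():
        # Each tuple carries one streamline's points and per-point scalars.
        for _ in range(100):
            n = rng.randint(2, 50)
            yield (rng.rand(n, 3).astype('float32'),
                   rng.rand(n, 5).astype('float32'))

    points_seq, scalars_seq = create_arraysequences_from_generator(gen(), n=2)
    print(len(points_seq), points_seq.common_shape)  # 100 (3,)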

nibabel/streamlines/tests/test_array_sequence.py

Lines changed: 31 additions & 0 deletions
@@ -200,6 +200,37 @@ def test_arraysequence_extend(self):
         seq = SEQ_DATA['seq'].copy()  # Copy because of in-place modification.
         assert_raises(ValueError, seq.extend, data)

+    def test_arraysequence_extend_using_coroutine(self):
+        new_data = generate_data(nb_arrays=10,
+                                 common_shape=SEQ_DATA['seq'].common_shape,
+                                 rng=SEQ_DATA['rng'])
+
+        # Extend with an empty list.
+        seq = SEQ_DATA['seq'].copy()  # Copy because of in-place modification.
+        coroutine = seq._extend_using_coroutine()
+        coroutine.send(None)
+        coroutine.close()
+        check_arr_seq(seq, SEQ_DATA['data'])
+
+        # Extend with a list of ndarrays.
+        seq = SEQ_DATA['seq'].copy()  # Copy because of in-place modification.
+        coroutine = seq._extend_using_coroutine()
+        coroutine.send(None)
+        for e in new_data:
+            coroutine.send(e)
+        coroutine.close()
+        check_arr_seq(seq, SEQ_DATA['data'] + new_data)
+
+        # Extend with elements of a different shape.
+        data = generate_data(nb_arrays=10,
+                             common_shape=SEQ_DATA['seq'].common_shape * 2,
+                             rng=SEQ_DATA['rng'])
+        seq = SEQ_DATA['seq'].copy()  # Copy because of in-place modification.
+        coroutine = seq._extend_using_coroutine()
+        coroutine.send(None)
+        assert_raises(ValueError, coroutine.send, data[0])
+
     def test_arraysequence_getitem(self):
         # Get one item
         for i, e in enumerate(SEQ_DATA['seq']):

nibabel/streamlines/trk.py

Lines changed: 1 addition & 1 deletion
@@ -425,7 +425,7 @@ def save(self, fileobj):
         property_name = np.zeros(MAX_NB_NAMED_PROPERTIES_PER_STREAMLINE,
                                  dtype='S20')
         for i, name in enumerate(data_for_streamline_keys):
-            # Use the last to bytes of the name to store the number of
+            # Use the last two bytes of the name to store the number of
             # values associated to this data_for_streamline.
            nb_values = data_for_streamline[name].shape[-1]
            property_name[i] = encode_value_in_name(nb_values, name)
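
The comment fixed above refers to how TRK stores the number of values per
named property: the count is smuggled into the fixed 20-byte (``S20``)
property-name field itself. The exact byte layout is defined by nibabel's
``encode_value_in_name``; the following is only an illustrative sketch with
hypothetical helpers, assuming a NUL separator between name and count::

    def pack_name_with_count(name, nb_values, max_len=20):
        # Hypothetical packer: append the count after a NUL byte when there
        # is more than one value, then pad to the fixed field width.
        packed = name if nb_values <= 1 else '%s\x00%d' % (name, nb_values)
        data = packed.encode('latin1')
        if len(data) > max_len:
            raise ValueError("Name too long to also carry the value count.")
        return data.ljust(max_len, b'\x00')

    def unpack_name_with_count(field):
        # Inverse of the sketch above: strip padding, split on the NUL.
        text = field.rstrip(b'\x00').decode('latin1')
        name, _, count = text.partition('\x00')
        return name, int(count) if count else 1

    field = pack_name_with_count('torsion', 3)
    print(unpack_name_with_count(field))  # ('torsion', 3)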
