Commit ea84c9d

RF: refactor tck read for speed
Use byte operations to avoid repeated array creation and copying when separating streamlines in the TCK format.
1 parent: 72d146a
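The gist of the change, as a standalone sketch rather than the nibabel API (`fiber_marker` and `split_streamlines` here are invented for illustration): instead of decoding the whole buffer into a float array and scanning for NaN rows on every pass, split the raw byte stream once on the delimiter's exact byte pattern and decode each piece.

import numpy as np

# Stand-in for TCK's fiber delimiter: one (NaN, NaN, NaN) triplet,
# encoded as little-endian float32 to match the on-disk bytes.
dtype = np.dtype('<f4')
fiber_marker = np.array([[np.nan] * 3], dtype=dtype).tobytes()

def split_streamlines(raw):
    # Split once on the delimiter bytes; decode each part to an (N, 3) array.
    for part in raw.split(fiber_marker):
        if part:  # skip empty chunks between back-to-back delimiters
            yield np.frombuffer(part, dtype=dtype).reshape(-1, 3).copy()

# Two short streamlines, each terminated by the delimiter.
raw = (np.zeros((2, 3), dtype=dtype).tobytes() + fiber_marker +
       np.ones((1, 3), dtype=dtype).tobytes() + fiber_marker)
print([s.shape for s in split_streamlines(raw)])  # [(2, 3), (1, 3)]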

File tree

2 files changed (+57, -52 lines)


nibabel/streamlines/tck.py

Lines changed: 42 additions & 48 deletions
@@ -21,7 +21,6 @@
 from .header import Field
 
 MEGABYTE = 1024 * 1024
-BUFFER_SIZE = 1000000
 
 
 def create_empty_header():
@@ -342,8 +341,8 @@ def _read_header(fileobj):
 
         return hdr
 
-    @staticmethod
-    def _read(fileobj, header, buffer_size=4):
+    @classmethod
+    def _read(cls, fileobj, header, buffer_size=4):
         """ Return generator that reads TCK data from `fileobj` given `header`
 
         Parameters
@@ -369,65 +368,60 @@ def _read(fileobj, header, buffer_size=4):
         buffer_size = int(buffer_size * MEGABYTE)
         buffer_size += coordinate_size - (buffer_size % coordinate_size)
 
+        # Markers for streamline end and file end
+        fiber_marker = cls.FIBER_DELIMITER.astype(dtype).tostring()
+        eof_marker = cls.EOF_DELIMITER.astype(dtype).tostring()
+
         with Opener(fileobj) as f:
             start_position = f.tell()
 
             # Set the file position at the beginning of the data.
             f.seek(header["_offset_data"], os.SEEK_SET)
 
             eof = False
-            buff = b""
-            pts = []
-
-            i = 0
-
-            while not eof or not np.all(np.isinf(pts)):
-
-                if not eof:
-                    bytes_read = f.read(buffer_size)
-                    buff += bytes_read
-                    eof = len(bytes_read) == 0
+            buffs = []
+            n_streams = 0
 
-                # Read floats.
-                pts = np.frombuffer(buff, dtype=dtype)
+            while not eof:
 
-                # Convert data to little-endian if needed.
-                if dtype != '<f4':
-                    pts = pts.astype('<f4')
-
-                pts = pts.reshape([-1, 3])
-                idx_nan = np.arange(len(pts))[np.isnan(pts[:, 0])]
+                bytes_read = f.read(buffer_size)
+                buffs.append(bytes_read)
+                eof = len(bytes_read) != buffer_size
 
                 # Make sure we've read enough to find a streamline delimiter.
-                if len(idx_nan) == 0:
+                if fiber_marker not in bytes_read:
                     # If we've read the whole file, then fail.
-                    if eof and not np.all(np.isinf(pts)):
-                        msg = ("Cannot find a streamline delimiter. This file"
-                               " might be corrupted.")
-                        raise DataError(msg)
-
-                    # Otherwise read a bit more.
-                    continue
-
-                nb_pts_total = 0
-                idx_start = 0
-                for idx_end in idx_nan:
-                    nb_pts = len(pts[idx_start:idx_end, :])
-                    nb_pts_total += nb_pts
-
-                    if nb_pts > 0:
-                        yield pts[idx_start:idx_end, :]
-                        i += 1
-
-                    idx_start = idx_end + 1
-
-                # Remove pts plus the first triplet of NaN.
-                nb_tiplets_to_remove = nb_pts_total + len(idx_nan)
-                nb_bytes_to_remove = nb_tiplets_to_remove * 3 * dtype.itemsize
-                buff = buff[nb_bytes_to_remove:]
+                    if eof:
+                        # Could have minimal buffering, and have read only the
+                        # EOF delimiter
+                        buffs = [b''.join(buffs)]
+                        if not buffs[0] == eof_marker:
+                            raise DataError(
+                                "Cannot find a streamline delimiter. This file"
+                                " might be corrupted.")
+                    else:
+                        # Otherwise read a bit more.
+                        continue
+
+                all_parts = b''.join(buffs).split(fiber_marker)
+                point_parts, buffs = all_parts[:-1], all_parts[-1:]
+                point_parts = [p for p in point_parts if p != b'']
+
+                for point_part in point_parts:
+                    # Read floats.
+                    pts = np.frombuffer(point_part, dtype=dtype)
+                    # Enforce ability to write to underlying bytes object
+                    pts.flags.writeable = True
+                    # Convert data to little-endian if needed.
+                    yield pts.astype('<f4', copy=False).reshape([-1, 3])
+
+                n_streams += len(point_parts)
+
+            if not buffs[-1] == eof_marker:
+                raise DataError("Expecting end-of-file marker 'inf inf inf'")
 
             # In case the 'count' field was not provided.
-            header[Field.NB_STREAMLINES] = i
+            header[Field.NB_STREAMLINES] = n_streams
 
             # Set the file position where it was (in case it was already open).
             f.seek(start_position, os.SEEK_CUR)
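A side effect worth noting in the new loop: `np.frombuffer` over a `bytes` object returns a read-only view, which is why the code above sets `pts.flags.writeable = True` before yielding. A minimal standalone demonstration of that behavior (not from the commit; note that NumPy releases from around 1.17 deprecate re-enabling the flag on a read-only buffer, where an explicit copy is the alternative):

import numpy as np

buf = np.arange(6, dtype='<f4').tobytes()  # an immutable bytes object
pts = np.frombuffer(buf, dtype='<f4')
print(pts.flags.writeable)                 # False: view over read-only memory
writable = pts.copy()                      # a copy is always writeable
writable[0] = 99.0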

nibabel/streamlines/tests/test_tck.py

Lines changed: 15 additions & 4 deletions
@@ -6,16 +6,16 @@
 from nibabel.externals.six import BytesIO
 from nibabel.py3k import asbytes
 
-from nose.tools import assert_equal, assert_raises
-
-from nibabel.testing import data_path
-from .test_tractogram import assert_tractogram_equal
 from ..array_sequence import ArraySequence
 from ..tractogram import Tractogram
 from ..tractogram_file import DataError
 
 from ..tck import TckFile
 
+from nose.tools import assert_equal, assert_raises, assert_true
+from numpy.testing import assert_array_equal
+from nibabel.testing import data_path
+from .test_tractogram import assert_tractogram_equal
 
 DATA = {}
 
@@ -62,6 +62,17 @@ def test_load_simple_file(self):
         tck = TckFile(tractogram, header=hdr)
         assert_tractogram_equal(tck.tractogram, DATA['simple_tractogram'])
 
+    def test_writeable_data(self):
+        data = DATA['simple_tractogram']
+        for key in ('simple_tck_fname', 'simple_tck_big_endian_fname'):
+            for lazy_load in [False, True]:
+                tck = TckFile.load(DATA[key], lazy_load=lazy_load)
+                for actual, expected_tgi in zip(tck.streamlines, data):
+                    assert_array_equal(actual, expected_tgi.streamline)
+                    # Test we can write to arrays
+                    assert_true(actual.flags.writeable)
+                    actual[0, 0] = 99
+
     def test_load_simple_file_in_big_endian(self):
         for lazy_load in [False, True]:
             tck = TckFile.load(DATA['simple_tck_big_endian_fname'],
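In user code, the property this test guards looks roughly like the following ('bundle.tck' is a placeholder path):

from nibabel.streamlines.tck import TckFile

tck = TckFile.load('bundle.tck', lazy_load=False)  # placeholder path
streamline = next(iter(tck.streamlines))
streamline[0, 0] = 99.0  # succeeds: the loaded arrays are writeable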
