diff --git a/nibabel/streamlines/tck.py b/nibabel/streamlines/tck.py index 823a88b5cf..c9bba94a6e 100644 --- a/nibabel/streamlines/tck.py +++ b/nibabel/streamlines/tck.py @@ -8,7 +8,6 @@ import warnings import numpy as np -from numpy.compat.py3k import asbytes, asstr from nibabel.openers import Opener @@ -44,7 +43,7 @@ class TckFile(TractogramFile): .. [#] http://nipy.org/nibabel/coordinate_systems.html#voxel-coordinates-are-in-voxel-space """ # Constants - MAGIC_NUMBER = "mrtrix tracks" + MAGIC_NUMBER = b"mrtrix tracks" SUPPORTS_DATA_PER_POINT = False # Not yet SUPPORTS_DATA_PER_STREAMLINE = False # Not yet @@ -94,7 +93,7 @@ def is_correct_format(cls, fileobj): magic_number = f.read(len(cls.MAGIC_NUMBER)) f.seek(-len(cls.MAGIC_NUMBER), os.SEEK_CUR) - return asstr(magic_number) == cls.MAGIC_NUMBER + return magic_number == cls.MAGIC_NUMBER @classmethod def create_empty_header(cls): @@ -230,7 +229,7 @@ def save(self, fileobj): header[Field.NB_STREAMLINES] = nb_streamlines # Add the EOF_DELIMITER. - f.write(asbytes(self.EOF_DELIMITER.tobytes())) + f.write(self.EOF_DELIMITER.tobytes()) self._finalize_header(f, header, offset=beginning) @staticmethod @@ -251,13 +250,11 @@ def _write_header(fileobj, header): "count", "datatype", "file"] # Fields being replaced. lines = [] - lines.append(asstr(header[Field.MAGIC_NUMBER])) lines.append(f"count: {header[Field.NB_STREAMLINES]:010}") lines.append("datatype: Float32LE") # Always Float32LE. lines.extend([f"{k}: {v}" for k, v in header.items() if k not in exclude and not k.startswith("_")]) - lines.append("file: . ") # Manually add this last field. out = "\n".join(lines) # Check the header is well formatted. @@ -265,27 +262,24 @@ def _write_header(fileobj, header): msg = f"Key-value pairs cannot contain '\\n':\n{out}" raise HeaderError(msg) - if out.count(":") > len(lines) - 1: + if out.count(":") > len(lines): # : only one per line (except the last one which contains END). msg = f"Key-value pairs cannot contain ':':\n{out}" raise HeaderError(msg) + out = header[Field.MAGIC_NUMBER] + b"\n" + out.encode('utf-8') + + # Compute data offset considering the offset string representation + # headers + "file" header + END + \n's + hdr_offset = len(out) + 8 + 3 + 3 + offset_repr = f'{hdr_offset}' + + # Adding the offset may increase one char to the offset repr + hdr_offset += len(f'{hdr_offset + len(offset_repr)}') + # Write header to file. - fileobj.write(asbytes(out)) - - hdr_len_no_offset = len(out) + 5 - # Need to add number of bytes to store offset as decimal string. We - # start with estimate without string, then update if the - # offset-as-decimal-string got longer after adding length of the - # offset string. - new_offset = -1 - old_offset = hdr_len_no_offset - while new_offset != old_offset: - old_offset = new_offset - new_offset = hdr_len_no_offset + len(str(old_offset)) - - fileobj.write(asbytes(str(new_offset) + "\n")) - fileobj.write(asbytes("END\n")) + fileobj.write(out) + fileobj.write(f'\nfile: . {hdr_offset}\nEND\n'.encode('utf-8')) @classmethod def _read_header(cls, fileobj): @@ -320,7 +314,7 @@ def _read_header(cls, fileobj): # Read magic number magic_number = f.read(len(cls.MAGIC_NUMBER)) - if asstr(magic_number) != cls.MAGIC_NUMBER: + if magic_number != cls.MAGIC_NUMBER: raise HeaderError(f"Invalid magic number: {magic_number}") hdr[Field.MAGIC_NUMBER] = magic_number @@ -331,7 +325,7 @@ def _read_header(cls, fileobj): # Read all key-value pairs contained in the header, stop at EOF for n_line, line in enumerate(f, 1): - line = asstr(line).strip() + line = line.decode('utf-8').strip() if not line: # Skip empty lines continue diff --git a/nibabel/streamlines/tests/test_tck.py b/nibabel/streamlines/tests/test_tck.py index 1cdda4b44e..75786c87c6 100644 --- a/nibabel/streamlines/tests/test_tck.py +++ b/nibabel/streamlines/tests/test_tck.py @@ -216,6 +216,34 @@ def test_write_simple_file(self): with pytest.raises(HeaderError): tck.save(tck_file) + def test_write_bigheader_file(self): + tractogram = Tractogram(DATA['streamlines'], + affine_to_rasmm=np.eye(4)) + + # Offset is represented by 2 characters. + tck_file = BytesIO() + tck = TckFile(tractogram) + tck.header['new_entry'] = ' ' * 20 + tck.save(tck_file) + tck_file.seek(0, os.SEEK_SET) + + new_tck = TckFile.load(tck_file) + assert_tractogram_equal(new_tck.tractogram, tractogram) + assert new_tck.header['_offset_data'] == 99 + + # We made the jump, now offset is represented by 3 characters + # and we need to adjust the offset! + tck_file = BytesIO() + tck = TckFile(tractogram) + tck.header['new_entry'] = ' ' * 21 + tck.save(tck_file) + tck_file.seek(0, os.SEEK_SET) + + new_tck = TckFile.load(tck_file) + assert_tractogram_equal(new_tck.tractogram, tractogram) + assert new_tck.header['_offset_data'] == 101 + + def test_load_write_file(self): for fname in [DATA['empty_tck_fname'], DATA['simple_tck_fname']]: