Skip to content

Commit 84dd196

Browse files
committed
RF+TST: add tests for en/decoding value in fields
Add tests for encoding / decoding numerical values in byte string fields. Refactor encoding. Update comments / docstrings to note new encoding with numbers as ASCII strings.
1 parent 959a1c7 commit 84dd196

File tree

2 files changed

+55
-28
lines changed

2 files changed

+55
-28
lines changed

nibabel/streamlines/tests/test_trk.py

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
from ..tractogram_file import HeaderError, HeaderWarning
1717

1818
from .. import trk as trk_module
19-
from ..trk import TrkFile
19+
from ..trk import TrkFile, encode_value_in_name, decode_value_from_name
2020
from ..header import Field
2121

2222
DATA = {}
@@ -468,3 +468,31 @@ def test_header_read_restore(self):
468468
assert_arr_dict_equal(TrkFile._read_header(bio), hdr_from_fname)
469469
# Check fileobject file position has not changed
470470
assert_equal(bio.tell(), hdr_pos)
471+
472+
473+
def test_encode_names():
474+
# Test function for encoding numbers into property names
475+
b0 = b'\x00'
476+
assert_equal(encode_value_in_name(0, 'foo', 10),
477+
b'foo' + b0 * 7)
478+
assert_equal(encode_value_in_name(1, 'foo', 10),
479+
b'foo' + b0 * 7)
480+
assert_equal(encode_value_in_name(8, 'foo', 10),
481+
b'foo' + b0 + b'8' + b0 * 5)
482+
assert_equal(encode_value_in_name(40, 'foobar', 10),
483+
b'foobar' + b0 + b'40' + b0)
484+
assert_equal(encode_value_in_name(1, 'foobarbazz', 10), b'foobarbazz')
485+
assert_raises(ValueError, encode_value_in_name, 1, 'foobarbazzz', 10)
486+
assert_raises(ValueError, encode_value_in_name, 2, 'foobarbaz', 10)
487+
assert_equal(encode_value_in_name(2, 'foobarba', 10), b'foobarba\x002')
488+
489+
490+
def test_decode_names():
491+
# Test function for decoding name string into name, number
492+
b0 = b'\x00'
493+
assert_equal(decode_value_from_name(b''), ('', 0))
494+
assert_equal(decode_value_from_name(b'foo' + b0 * 7), ('foo', 1))
495+
assert_equal(decode_value_from_name(b'foo\x008' + b0 * 5), ('foo', 8))
496+
assert_equal(decode_value_from_name(b'foobar\x0010\x00'), ('foobar', 10))
497+
assert_raises(ValueError, decode_value_from_name, b'foobar\x0010\x01')
498+
assert_raises(HeaderError, decode_value_from_name, b'foo\x0010\x00111')

nibabel/streamlines/trk.py

Lines changed: 26 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -129,48 +129,46 @@ def get_affine_rasmm_to_trackvis(header):
129129

130130

131131
def encode_value_in_name(value, name, max_name_len=20):
132-
""" Encodes a value in the last bytes of a string.
132+
""" Return `name` as fixed-length string, appending `value` as string.
133133
134-
If `value` is one, then there is no encoding and the last bytes
135-
are left untouched. Otherwise, a \x00 byte is added after `name`
136-
and followed by the ascii represensation of the value.
134+
Form output from `name` if `value <= 1` else `name` + ``\x00`` +
135+
str(value).
137136
138-
This function also verifies that the length of name is less
139-
than `max_name_len`.
137+
Return output as fixed length string length `max_name_len`, padded with
138+
``\x00``.
139+
140+
This function also verifies that the modified length of name is less than
141+
`max_name_len`.
140142
141143
Parameters
142144
----------
143-
value : byte
144-
Integer value between 0 and 255 to encode.
145-
name : bytes
146-
Name in which the last two bytes will serve to encode `value`.
145+
value : int
146+
Integer value to encode.
147+
name : str
148+
Name to which we may append an ascii / latin-1 representation of
149+
`value`.
147150
max_name_len : int, optional
148-
Maximum length name can have.
151+
Maximum length of byte string that output can have.
149152
150153
Returns
151154
-------
152155
encoded_name : bytes
153-
Name containing the encoded value.
156+
Name maybe followed by ``\x00`` and ascii / latin-1 representation of
157+
`value`, padded with ``\x00`` bytes.
154158
"""
155-
156159
if len(name) > max_name_len:
157160
msg = ("Data information named '{0}' is too long"
158161
" (max {1} characters.)").format(name, max_name_len)
159162
raise ValueError(msg)
160-
elif value > 1 and len(name) + len(str(value)) + 1 > max_name_len:
163+
encoded_name = name if value <= 1 else name + '\x00' + str(value)
164+
if len(encoded_name) > max_name_len:
161165
msg = ("Data information named '{0}' is too long (need to be less"
162166
" than {1} characters when storing more than one value"
163167
" for a given data information."
164168
).format(name, max_name_len - (len(str(value)) + 1))
165169
raise ValueError(msg)
166-
167-
encoded_name = name
168-
if value > 1:
169-
# Store the name followed by \x00 and the `value` (in ascii).
170-
encoded_name += '\x00' + str(value)
171-
172-
encoded_name = encoded_name.ljust(max_name_len, '\x00')
173-
return encoded_name
170+
# Fill to the end with zeros
171+
return encoded_name.ljust(max_name_len, '\x00').encode('latin1')
174172

175173

176174
def decode_value_from_name(encoded_name):
@@ -388,7 +386,7 @@ def _read():
388386
return cls(tractogram, header=hdr)
389387

390388
def save(self, fileobj):
391-
""" Saves tractogram to a file-like object using TRK format.
389+
""" Save tractogram to a filename or file-like object using TRK format.
392390
393391
Parameters
394392
----------
@@ -420,6 +418,7 @@ def save(self, fileobj):
420418
# Keep track of the beginning of the header.
421419
beginning = f.tell()
422420

421+
# Write temporary header that we will update at the end
423422
f.write(header.tostring())
424423

425424
i4_dtype = np.dtype("<i4") # Always save in little-endian.
@@ -449,8 +448,8 @@ def save(self, fileobj):
449448
property_name = np.zeros(MAX_NB_NAMED_PROPERTIES_PER_STREAMLINE,
450449
dtype='S20')
451450
for i, name in enumerate(data_for_streamline_keys):
452-
# Use the last two bytes of the name to store the number of
453-
# values associated to this data_for_streamline.
451+
# Append number of values as ascii to zero-terminated name
452+
# to encode number of values into trackvis name.
454453
nb_values = data_for_streamline[name].shape[-1]
455454
property_name[i] = encode_value_in_name(nb_values, name)
456455
header['property_name'][:] = property_name
@@ -466,8 +465,8 @@ def save(self, fileobj):
466465
data_for_points_keys = sorted(data_for_points.keys())
467466
scalar_name = np.zeros(MAX_NB_NAMED_SCALARS_PER_POINT, dtype='S20')
468467
for i, name in enumerate(data_for_points_keys):
469-
# Use the last two bytes of the name to store the number of
470-
# values associated to this data_for_streamline.
468+
# Append number of values as ascii to zero-terminated name
469+
# to encode number of values into trackvis name.
471470
nb_values = data_for_points[name].shape[-1]
472471
scalar_name[i] = encode_value_in_name(nb_values, name)
473472
header['scalar_name'][:] = scalar_name

0 commit comments

Comments
 (0)