Skip to content

Commit c69a5ba

Browse files
committed
Add docstrings
1 parent cc6b31b commit c69a5ba

File tree

2 files changed

+73
-13
lines changed

2 files changed

+73
-13
lines changed

nibabel/streamlines/tests/test_tractogram.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -181,6 +181,21 @@ def test_getitem(self):
181181
assert_arrays_equal(sdict[-1][k], v[-1])
182182
assert_arrays_equal(sdict[[0, -1]][k], v[[0, -1]])
183183

184+
def test_extend(self):
185+
sdict = PerArrayDict(len(DATA['tractogram']),
186+
DATA['data_per_streamline'])
187+
sdict2 = PerArrayDict(len(DATA['tractogram']),
188+
DATA['data_per_streamline'])
189+
190+
sdict += sdict2
191+
assert_equal(len(sdict), len(sdict2))
192+
for k, v in DATA['tractogram'].data_per_streamline.items():
193+
assert_arrays_equal(sdict[k][:len(DATA['tractogram'])], v)
194+
assert_arrays_equal(sdict[k][len(DATA['tractogram']):], v)
195+
196+
# Test incompatible PerArrayDicts.
197+
assert_raises(ValueError, sdict.extend, PerArrayDict())
198+
184199

185200
class TestPerArraySequenceDict(unittest.TestCase):
186201

@@ -233,6 +248,20 @@ def test_getitem(self):
233248
assert_arrays_equal(sdict[-1][k], v[-1])
234249
assert_arrays_equal(sdict[[0, -1]][k], v[[0, -1]])
235250

251+
def test_extend(self):
252+
total_nb_rows = DATA['tractogram'].streamlines.total_nb_rows
253+
sdict = PerArraySequenceDict(total_nb_rows, DATA['data_per_point'])
254+
sdict2 = PerArraySequenceDict(total_nb_rows, DATA['data_per_point'])
255+
256+
sdict += sdict2
257+
assert_equal(len(sdict), len(sdict2))
258+
for k, v in DATA['tractogram'].data_per_point.items():
259+
assert_arrays_equal(sdict[k][:len(DATA['tractogram'])], v)
260+
assert_arrays_equal(sdict[k][len(DATA['tractogram']):], v)
261+
262+
# Test incompatible PerArrayDicts.
263+
assert_raises(ValueError, sdict.extend, PerArraySequenceDict())
264+
236265

237266
class TestLazyDict(unittest.TestCase):
238267

nibabel/streamlines/tractogram.py

Lines changed: 44 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -113,14 +113,22 @@ def __setitem__(self, key, value):
113113
self.store[key] = value
114114

115115
def extend(self, other):
116-
if len(self) != len(other):
117-
msg = ("Size mismatched between the two PerArrayDict objects."
118-
" This PerArrayDict has {0} elements whereas the other "
119-
" has {1} elements.").format(len(self), len(other))
120-
raise ValueError(msg)
116+
""" Appends the elements of another :class:`PerArrayDict`.
117+
118+
That is, for each entry in this dictionary, we append the elements
119+
coming from the other dictionary at the corresponding entry.
120+
121+
Parameters
122+
----------
123+
other : :class:`PerArrayDict` object
124+
Its data will be appended to the data of this dictionary.
121125
126+
Notes
127+
-----
128+
The entries in both dictionaries must match.
129+
"""
122130
if sorted(self.keys()) != sorted(other.keys()):
123-
msg = ("Key mismatched between the two PerArrayDict objects."
131+
msg = ("Entry mismatched between the two PerArrayDict objects."
124132
" This PerArrayDict contains '{0}' whereas the other "
125133
" contains '{1}'.").format(sorted(self.keys()),
126134
sorted(other.keys()))
@@ -159,12 +167,20 @@ def __setitem__(self, key, value):
159167
self.store[key] = value
160168

161169
def extend(self, other):
162-
if len(self) != len(other):
163-
msg = ("Size mismatched between the two PerArrayDict objects."
164-
" This PerArrayDict has {0} elements whereas the other "
165-
" has {1} elements.").format(len(self), len(other))
166-
raise ValueError(msg)
170+
""" Appends the elements of another :class:`PerArraySequenceDict`.
167171
172+
That is, for each entry in this dictionary, we append the elements
173+
coming from the other dictionary at the corresponding entry.
174+
175+
Parameters
176+
----------
177+
other : :class:`PerArraySequenceDict` object
178+
Its data will be appended to the data of this dictionary.
179+
180+
Notes
181+
-----
182+
The entries in both dictionaries must match.
183+
"""
168184
if sorted(self.keys()) != sorted(other.keys()):
169185
msg = ("Key mismatched between the two PerArrayDict objects."
170186
" This PerArrayDict contains '{0}' whereas the other "
@@ -463,7 +479,21 @@ def to_world(self, lazy=False):
463479
return self.apply_affine(self.affine_to_rasmm, lazy=lazy)
464480

465481
def extend(self, other):
466-
# TODO: Make sure the other tractogram is compatible.
482+
""" Appends the data of another :class:`Tractogram`.
483+
484+
Data that will be appended includes the streamlines and the content
485+
of both dictionaries `data_per_streamline` and `data_per_point`.
486+
487+
Parameters
488+
----------
489+
other : :class:`Tractogram` object
490+
Its data will be appended to the data of this tractogram.
491+
492+
Notes
493+
-----
494+
The entries of `self.data_per_streamline` and `self.data_per_point`
495+
must match those contained in the other tractogram.
496+
"""
467497
self.streamlines.extend(other.streamlines)
468498
self.data_per_streamline += other.data_per_streamline
469499
self.data_per_point += other.data_per_point
@@ -708,7 +738,8 @@ def __getitem__(self, idx):
708738
raise NotImplementedError('LazyTractogram does not support indexing.')
709739

710740
def extend(self, other):
711-
raise NotImplementedError('LazyTractogram does not support concatenation.')
741+
msg = 'LazyTractogram does not support concatenation.'
742+
raise NotImplementedError(msg)
712743

713744
def __iter__(self):
714745
count = 0

0 commit comments

Comments
 (0)