Skip to content

Commit cc6b31b

Browse files
committed
Add Tractogram concatenation
1 parent 6104bd1 commit cc6b31b

File tree

2 files changed

+73
-0
lines changed

2 files changed

+73
-0
lines changed

nibabel/streamlines/tests/test_tractogram.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -570,6 +570,17 @@ def test_tractogram_to_world(self):
570570
tractogram.affine_to_rasmm = None
571571
assert_raises(ValueError, tractogram.to_world)
572572

573+
def test_tractogram_extend(self):
574+
# Load tractogram that contains some metadata.
575+
t = DATA['tractogram'].copy()
576+
new_t = DATA['tractogram'].copy()
577+
578+
# Double the tractogram.
579+
new_t += t
580+
assert_equal(len(new_t), 2*len(t))
581+
assert_tractogram_equal(new_t[:len(t)], DATA['tractogram'])
582+
assert_tractogram_equal(new_t[len(t):], DATA['tractogram'])
583+
573584

574585
class TestLazyTractogram(unittest.TestCase):
575586

@@ -641,6 +652,11 @@ def test_lazy_tractogram_getitem(self):
641652
assert_raises(NotImplementedError,
642653
DATA['lazy_tractogram'].__getitem__, 0)
643654

655+
def test_lazy_tractogram_extend(self):
656+
t = DATA['lazy_tractogram'].copy()
657+
new_t = DATA['lazy_tractogram'].copy()
658+
assert_raises(NotImplementedError, new_t.__iadd__, t)
659+
644660
def test_lazy_tractogram_len(self):
645661
modules = [module_tractogram] # Modules for which to catch warnings.
646662
with clear_and_catch_warnings(record=True, modules=modules) as w:

nibabel/streamlines/tractogram.py

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,28 @@ def __setitem__(self, key, value):
112112

113113
self.store[key] = value
114114

115+
def extend(self, other):
116+
if len(self) != len(other):
117+
msg = ("Size mismatched between the two PerArrayDict objects."
118+
" This PerArrayDict has {0} elements whereas the other "
119+
" has {1} elements.").format(len(self), len(other))
120+
raise ValueError(msg)
121+
122+
if sorted(self.keys()) != sorted(other.keys()):
123+
msg = ("Key mismatched between the two PerArrayDict objects."
124+
" This PerArrayDict contains '{0}' whereas the other "
125+
" contains '{1}'.").format(sorted(self.keys()),
126+
sorted(other.keys()))
127+
raise ValueError(msg)
128+
129+
self.n_rows += other.n_rows
130+
for key in self.keys():
131+
self[key] = np.concatenate([self[key], other[key]])
132+
133+
def __iadd__(self, other):
134+
self.extend(other)
135+
return self
136+
115137

116138
class PerArraySequenceDict(PerArrayDict):
117139
""" Dictionary for which key access can do slicing on the values.
@@ -136,6 +158,28 @@ def __setitem__(self, key, value):
136158

137159
self.store[key] = value
138160

161+
def extend(self, other):
162+
if len(self) != len(other):
163+
msg = ("Size mismatched between the two PerArrayDict objects."
164+
" This PerArrayDict has {0} elements whereas the other "
165+
" has {1} elements.").format(len(self), len(other))
166+
raise ValueError(msg)
167+
168+
if sorted(self.keys()) != sorted(other.keys()):
169+
msg = ("Key mismatched between the two PerArrayDict objects."
170+
" This PerArrayDict contains '{0}' whereas the other "
171+
" contains '{1}'.").format(sorted(self.keys()),
172+
sorted(other.keys()))
173+
raise ValueError(msg)
174+
175+
self.n_rows += other.n_rows
176+
for key in self.keys():
177+
self[key].extend(other[key])
178+
179+
def __iadd__(self, other):
180+
self.extend(other)
181+
return self
182+
139183

140184
class LazyDict(collections.MutableMapping):
141185
""" Dictionary of generator functions.
@@ -418,6 +462,16 @@ def to_world(self, lazy=False):
418462

419463
return self.apply_affine(self.affine_to_rasmm, lazy=lazy)
420464

465+
def extend(self, other):
466+
# TODO: Make sure the other tractogram is compatible.
467+
self.streamlines.extend(other.streamlines)
468+
self.data_per_streamline += other.data_per_streamline
469+
self.data_per_point += other.data_per_point
470+
471+
def __iadd__(self, other):
472+
self.extend(other)
473+
return self
474+
421475

422476
class LazyTractogram(Tractogram):
423477
""" Lazy container for streamlines and their data information.
@@ -653,6 +707,9 @@ def _gen_data():
653707
def __getitem__(self, idx):
654708
raise NotImplementedError('LazyTractogram does not support indexing.')
655709

710+
def extend(self, other):
711+
raise NotImplementedError('LazyTractogram does not support concatenation.')
712+
656713
def __iter__(self):
657714
count = 0
658715
for tractogram_item in self.data:

0 commit comments

Comments
 (0)