Skip to content

Commit 3b3131a

Browse files
committed
BF: only apply affine to selected streamlines
1 parent ddf2683 commit 3b3131a

File tree

5 files changed

+140
-5
lines changed

5 files changed

+140
-5
lines changed

nibabel/streamlines/array_sequence.py

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -329,6 +329,54 @@ def __getitem__(self, idx):
329329
raise TypeError("Index must be either an int, a slice, a list of int"
330330
" or a ndarray of bool! Not " + str(type(idx)))
331331

332+
def __setitem__(self, idx, elements):
333+
""" Set sequence(s) through standard or advanced numpy indexing.
334+
335+
Parameters
336+
----------
337+
idx : int or slice or list or ndarray
338+
If int, index of the element to retrieve.
339+
If slice, use slicing to retrieve elements.
340+
If list, indices of the elements to retrieve.
341+
If ndarray with dtype int, indices of the elements to retrieve.
342+
If ndarray with dtype bool, only retrieve selected elements.
343+
elements: ndarray or :class:`ArraySequence`
344+
Data that will overwrite selected sequences.
345+
If `idx` is an int, `elements` is expected to be a ndarray.
346+
Otherwise, `elements` is expected a :class:`ArraySequence` object.
347+
"""
348+
if isinstance(idx, (numbers.Integral, np.integer)):
349+
start = self._offsets[idx]
350+
self._data[start:start + self._lengths[idx]] = elements
351+
return
352+
353+
if isinstance(idx, tuple):
354+
off_idx = idx[0]
355+
data = self._data.__getitem__((slice(None),) + idx[1:])
356+
else:
357+
off_idx = idx
358+
data = self._data
359+
360+
if isinstance(off_idx, slice): # Standard list slicing
361+
offsets = self._offsets[off_idx]
362+
lengths = self._lengths[off_idx]
363+
364+
elif isinstance(off_idx, list) or is_ndarray_of_int_or_bool(off_idx):
365+
# Fancy indexing
366+
offsets = self._offsets[off_idx]
367+
lengths = self._lengths[off_idx]
368+
369+
else:
370+
raise TypeError("Index must be either an int, a slice, a list of int"
371+
" or a ndarray of bool! Not " + str(type(idx)))
372+
373+
if len(lengths) != elements.total_nb_rows:
374+
msg = "Trying to set {} sequences with {} sequences."
375+
raise TypeError(msg.format(len(lengths), elements.total_nb_rows))
376+
377+
for o1, l1, o2, l2 in zip(offsets, lengths, elements._offsets, elements._lengths):
378+
data[o1:o1 + l1] = elements._data[o2:o2 + l2]
379+
332380
def __iter__(self):
333381
if len(self._lengths) != len(self._offsets):
334382
raise ValueError("ArraySequence object corrupted:"

nibabel/streamlines/tests/test_array_sequence.py

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -277,6 +277,67 @@ def test_arraysequence_getitem(self):
277277
check_arr_seq_view(seq_view, SEQ_DATA['seq'])
278278
check_arr_seq(seq_view, [d[:, 2] for d in SEQ_DATA['data'][::-2]])
279279

280+
def test_arraysequence_setitem(self):
281+
# Set one item
282+
seq = SEQ_DATA['seq'] * 0
283+
for i, e in enumerate(SEQ_DATA['seq']):
284+
seq[i] = e
285+
286+
check_arr_seq(seq, SEQ_DATA['seq'])
287+
288+
if sys.version_info < (3,):
289+
seq = ArraySequence(SEQ_DATA['seq'] * 0)
290+
for i, e in enumerate(SEQ_DATA['seq']):
291+
seq[long(i)] = e
292+
293+
check_arr_seq(seq, SEQ_DATA['seq'])
294+
295+
# Get all items using indexing (creates a view).
296+
indices = list(range(len(SEQ_DATA['seq'])))
297+
seq_view = SEQ_DATA['seq'][indices]
298+
check_arr_seq_view(seq_view, SEQ_DATA['seq'])
299+
# We took all elements so the view should match the original.
300+
check_arr_seq(seq_view, SEQ_DATA['seq'])
301+
302+
# Get multiple items using ndarray of dtype integer.
303+
for dtype in [np.int8, np.int16, np.int32, np.int64]:
304+
seq_view = SEQ_DATA['seq'][np.array(indices, dtype=dtype)]
305+
check_arr_seq_view(seq_view, SEQ_DATA['seq'])
306+
# We took all elements so the view should match the original.
307+
check_arr_seq(seq_view, SEQ_DATA['seq'])
308+
309+
# Get multiple items out of order (creates a view).
310+
SEQ_DATA['rng'].shuffle(indices)
311+
seq_view = SEQ_DATA['seq'][indices]
312+
check_arr_seq_view(seq_view, SEQ_DATA['seq'])
313+
check_arr_seq(seq_view, [SEQ_DATA['data'][i] for i in indices])
314+
315+
# Get slice (this will create a view).
316+
seq_view = SEQ_DATA['seq'][::2]
317+
check_arr_seq_view(seq_view, SEQ_DATA['seq'])
318+
check_arr_seq(seq_view, SEQ_DATA['data'][::2])
319+
320+
# Use advanced indexing with ndarray of data type bool.
321+
selection = np.array([False, True, True, False, True])
322+
seq_view = SEQ_DATA['seq'][selection]
323+
check_arr_seq_view(seq_view, SEQ_DATA['seq'])
324+
check_arr_seq(seq_view,
325+
[SEQ_DATA['data'][i]
326+
for i, keep in enumerate(selection) if keep])
327+
328+
# Test invalid indexing
329+
assert_raises(TypeError, SEQ_DATA['seq'].__getitem__, 'abc')
330+
331+
# Get specific columns.
332+
seq_view = SEQ_DATA['seq'][:, 2]
333+
check_arr_seq_view(seq_view, SEQ_DATA['seq'])
334+
check_arr_seq(seq_view, [d[:, 2] for d in SEQ_DATA['data']])
335+
336+
# Combining multiple slicing and indexing operations.
337+
seq_view = SEQ_DATA['seq'][::-2][:, 2]
338+
check_arr_seq_view(seq_view, SEQ_DATA['seq'])
339+
check_arr_seq(seq_view, [d[:, 2] for d in SEQ_DATA['data'][::-2]])
340+
280341
def test_arraysequence_repr(self):
281342
# Test that calling repr on a ArraySequence object is not falling.
282343
repr(SEQ_DATA['seq'])

nibabel/streamlines/tests/test_streamlines.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -267,6 +267,19 @@ def test_save_complex_file(self):
267267
tfile = nib.streamlines.load(filename, lazy_load=False)
268268
assert_tractogram_equal(tfile.tractogram, tractogram)
269269

270+
def test_save_sliced_tractogram(self):
271+
tractogram = Tractogram(DATA['streamlines'],
272+
affine_to_rasmm=np.eye(4))
273+
original_tractogram = tractogram.copy()
274+
for ext, cls in FORMATS.items():
275+
with InTemporaryDirectory():
276+
filename = 'streamlines' + ext
277+
nib.streamlines.save(tractogram[::2], filename)
278+
tfile = nib.streamlines.load(filename, lazy_load=False)
279+
assert_tractogram_equal(tfile.tractogram, tractogram[::2])
280+
# Make sure original tractogram hasn't changed.
281+
assert_tractogram_equal(tractogram, original_tractogram)
282+
270283
def test_load_unknown_format(self):
271284
assert_raises(ValueError, nib.streamlines.load, "")
272285

nibabel/streamlines/tests/test_tractogram.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -688,6 +688,22 @@ def test_tractogram_apply_affine(self):
688688
np.dot(np.eye(4), np.dot(np.linalg.inv(affine),
689689
np.linalg.inv(affine))))
690690

691+
# Applying the affine to a tractogram that has been indexed or sliced
692+
# shouldn't affect the remaining streamlines.
693+
tractogram = DATA['tractogram'].copy()
694+
transformed_tractogram = tractogram[::2].apply_affine(affine)
695+
assert_true(transformed_tractogram is not tractogram)
696+
check_tractogram(tractogram[::2],
697+
streamlines=[s*scaling for s in DATA['streamlines'][::2]],
698+
data_per_streamline=DATA['tractogram'].data_per_streamline[::2],
699+
data_per_point=DATA['tractogram'].data_per_point[::2])
700+
701+
# Remaining streamlines should match the original ones.
702+
check_tractogram(tractogram[1::2],
703+
streamlines=DATA['streamlines'][1::2],
704+
data_per_streamline=DATA['tractogram'].data_per_streamline[1::2],
705+
data_per_point=DATA['tractogram'].data_per_point[1::2])
706+
691707
# Check that applying an affine and its inverse give us back the
692708
# original streamlines.
693709
tractogram = DATA['tractogram'].copy()

nibabel/streamlines/tractogram.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -432,11 +432,8 @@ def apply_affine(self, affine, lazy=False):
432432
if np.all(affine == np.eye(4)):
433433
return self # No transformation.
434434

435-
BUFFER_SIZE = 10000000 # About 128 Mb since pts shape is 3.
436-
for start in range(0, len(self.streamlines.data), BUFFER_SIZE):
437-
end = start + BUFFER_SIZE
438-
pts = self.streamlines._data[start:end]
439-
self.streamlines.data[start:end] = apply_affine(affine, pts)
435+
for i in range(len(self.streamlines)):
436+
self.streamlines[i] = apply_affine(affine, self.streamlines[i])
440437

441438
if self.affine_to_rasmm is not None:
442439
# Update the affine that brings back the streamlines to RASmm.

0 commit comments

Comments
 (0)