Skip to content

Commit df65e94

Browse files
committed
Setitem method for ArraySequence
1 parent 6e65107 commit df65e94

File tree

2 files changed

+43
-46
lines changed

2 files changed

+43
-46
lines changed

nibabel/streamlines/array_sequence.py

Lines changed: 19 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -434,12 +434,26 @@ def __setitem__(self, idx, elements):
434434
raise TypeError("Index must be either an int, a slice, a list of int"
435435
" or a ndarray of bool! Not " + str(type(idx)))
436436

437-
if len(lengths) != elements.total_nb_rows:
438-
msg = "Trying to set {} sequences with {} sequences."
439-
raise TypeError(msg.format(len(lengths), elements.total_nb_rows))
437+
if is_array_sequence(elements):
438+
if len(lengths) != len(elements):
439+
msg = "Trying to set {} sequences with {} sequences."
440+
raise ValueError(msg.format(len(lengths), len(elements)))
441+
442+
if sum(lengths) != elements.total_nb_rows:
443+
msg = "Trying to set {} points with {} points."
444+
raise ValueError(msg.format(sum(lengths), elements.total_nb_rows))
445+
446+
for o1, l1, o2, l2 in zip(offsets, lengths, elements._offsets, elements._lengths):
447+
data[o1:o1 + l1] = elements._data[o2:o2 + l2]
448+
449+
elif isinstance(elements, numbers.Number):
450+
for o1, l1 in zip(offsets, lengths):
451+
data[o1:o1 + l1] = elements
452+
453+
else: # Try to iterate over it.
454+
for o1, l1, data in zip(offsets, lengths, elements):
455+
data[o1:o1 + l1] = data
440456

441-
for o1, l1, o2, l2 in zip(offsets, lengths, elements._offsets, elements._lengths):
442-
data[o1:o1 + l1] = elements._data[o2:o2 + l2]
443457

444458
def _op(self, op, value=None, inplace=False):
445459
""" Applies some operator to this arraysequence.

nibabel/streamlines/tests/test_array_sequence.py

Lines changed: 24 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -375,7 +375,6 @@ def _test_binary(op, arrseq, scalars, seqs, inplace=False):
375375
# Restore flags.
376376
np.seterr(**flags)
377377

378-
379378
def test_arraysequence_setitem(self):
380379
# Set one item
381380
seq = SEQ_DATA['seq'] * 0
@@ -384,51 +383,35 @@ def test_arraysequence_setitem(self):
384383

385384
check_arr_seq(seq, SEQ_DATA['seq'])
386385

387-
# Get all items using indexing (creates a view).
388-
indices = list(range(len(SEQ_DATA['seq'])))
389-
seq_view = SEQ_DATA['seq'][indices]
390-
check_arr_seq_view(seq_view, SEQ_DATA['seq'])
391-
# We took all elements so the view should match the original.
392-
check_arr_seq(seq_view, SEQ_DATA['seq'])
393-
394-
# Get multiple items using ndarray of dtype integer.
395-
for dtype in [np.int8, np.int16, np.int32, np.int64]:
396-
seq_view = SEQ_DATA['seq'][np.array(indices, dtype=dtype)]
397-
check_arr_seq_view(seq_view, SEQ_DATA['seq'])
398-
# We took all elements so the view should match the original.
399-
check_arr_seq(seq_view, SEQ_DATA['seq'])
400-
401-
# Get multiple items out of order (creates a view).
402-
SEQ_DATA['rng'].shuffle(indices)
403-
seq_view = SEQ_DATA['seq'][indices]
404-
check_arr_seq_view(seq_view, SEQ_DATA['seq'])
405-
check_arr_seq(seq_view, [SEQ_DATA['data'][i] for i in indices])
386+
# Setitem with a scalar.
387+
seq = SEQ_DATA['seq'].copy()
388+
seq[:] = 0
389+
assert_true(seq._data.sum() == 0)
406390

407-
# Get slice (this will create a view).
408-
seq_view = SEQ_DATA['seq'][::2]
409-
check_arr_seq_view(seq_view, SEQ_DATA['seq'])
410-
check_arr_seq(seq_view, SEQ_DATA['data'][::2])
391+
# Setitem with a list of ndarray.
392+
seq = SEQ_DATA['seq'].copy()
393+
for i, data in enumerate(SEQ_DATA['data']):
394+
seq[i] = data
395+
check_arr_seq(seq, SEQ_DATA['data'])
411396

412-
# Use advanced indexing with ndarray of data type bool.
413-
selection = np.array([False, True, True, False, True])
414-
seq_view = SEQ_DATA['seq'][selection]
415-
check_arr_seq_view(seq_view, SEQ_DATA['seq'])
416-
check_arr_seq(seq_view,
417-
[SEQ_DATA['data'][i]
418-
for i, keep in enumerate(selection) if keep])
397+
# Setting a slice using another slice.
398+
seq = ArraySequence(np.arange(900).reshape((50,6,3)))
399+
seq[0:4] = seq[5:9]
400+
check_arr_seq(seq[0:4], seq[5:9])
419401

420-
# Test invalid indexing
421-
assert_raises(TypeError, SEQ_DATA['seq'].__getitem__, 'abc')
402+
# Setting a slice using another slice with more sequences.
403+
seq = ArraySequence(np.arange(900).reshape((50,6,3)))
404+
assert_raises(ValueError, seq.__setitem__, slice(0, 4), seq[5:10])
422405

423-
# Get specific columns.
424-
seq_view = SEQ_DATA['seq'][:, 2]
425-
check_arr_seq_view(seq_view, SEQ_DATA['seq'])
426-
check_arr_seq(seq_view, [d[:, 2] for d in SEQ_DATA['data']])
406+
# Setitem between array sequences with different amount of points.
407+
seq1 = ArraySequence(np.arange(10).reshape(5, 2))
408+
seq2 = ArraySequence(np.arange(15).reshape(5, 3))
409+
assert_raises(ValueError, seq1.__setitem__, slice(0, 5), seq2)
427410

428-
# Combining multiple slicing and indexing operations.
429-
seq_view = SEQ_DATA['seq'][::-2][:, 2]
430-
check_arr_seq_view(seq_view, SEQ_DATA['seq'])
431-
check_arr_seq(seq_view, [d[:, 2] for d in SEQ_DATA['data'][::-2]])
411+
# Setitem between array sequences with different common shape.
412+
seq1 = ArraySequence(np.arange(12).reshape(2, 2, 3))
413+
seq2 = ArraySequence(np.arange(8).reshape(2, 2, 2))
414+
assert_raises(ValueError, seq1.__setitem__, slice(0, 2), seq2)
432415

433416
def test_arraysequence_repr(self):
434417
# Test that calling repr on a ArraySequence object is not falling.

0 commit comments

Comments
 (0)