Skip to content

Commit 1b1f516

Browse files
committed
Setitem method for ArraySequence
1 parent 7400811 commit 1b1f516

File tree

2 files changed

+77
-67
lines changed

2 files changed

+77
-67
lines changed

nibabel/streamlines/array_sequence.py

Lines changed: 27 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,9 @@
66

77
import numpy as np
88

9+
from nibabel.deprecated import deprecate_with_version
10+
11+
912
MEGABYTE = 1024 * 1024
1013

1114

@@ -148,16 +151,15 @@ def total_nb_rows(self):
148151
return np.sum(self._lengths)
149152

150153
@property
154+
@deprecate_with_version("'ArraySequence.data' property is deprecated.\n"
155+
"Please use the 'ArraySequence.get_data()' method instead",
156+
'3.0', '4.0')
151157
def data(self):
152158
""" Elements in this array sequence. """
153-
warnings.warn("The 'ArraySequence.data' property has been deprecated"
154-
" in favor of 'ArraySequence.get_data()'.",
155-
DeprecationWarning,
156-
stacklevel=2)
157159
return self.get_data()
158160

159161
def get_data(self):
160-
""" Returns a copy of the elements in this array sequence.
162+
""" Returns a *copy* of the elements in this array sequence.
161163
162164
Notes
163165
-----
@@ -384,7 +386,7 @@ def __getitem__(self, idx):
384386
seq._lengths = self._lengths[off_idx]
385387
return seq
386388

387-
if isinstance(off_idx, list) or is_ndarray_of_int_or_bool(off_idx):
389+
if isinstance(off_idx, (list, range)) or is_ndarray_of_int_or_bool(off_idx):
388390
# Fancy indexing
389391
seq._offsets = self._offsets[off_idx]
390392
seq._lengths = self._lengths[off_idx]
@@ -425,7 +427,7 @@ def __setitem__(self, idx, elements):
425427
offsets = self._offsets[off_idx]
426428
lengths = self._lengths[off_idx]
427429

428-
elif isinstance(off_idx, list) or is_ndarray_of_int_or_bool(off_idx):
430+
elif isinstance(off_idx, (list, range)) or is_ndarray_of_int_or_bool(off_idx):
429431
# Fancy indexing
430432
offsets = self._offsets[off_idx]
431433
lengths = self._lengths[off_idx]
@@ -434,12 +436,25 @@ def __setitem__(self, idx, elements):
434436
raise TypeError("Index must be either an int, a slice, a list of int"
435437
" or a ndarray of bool! Not " + str(type(idx)))
436438

437-
if len(lengths) != elements.total_nb_rows:
438-
msg = "Trying to set {} sequences with {} sequences."
439-
raise TypeError(msg.format(len(lengths), elements.total_nb_rows))
439+
if is_array_sequence(elements):
440+
if len(lengths) != len(elements):
441+
msg = "Trying to set {} sequences with {} sequences."
442+
raise ValueError(msg.format(len(lengths), len(elements)))
443+
444+
if sum(lengths) != elements.total_nb_rows:
445+
msg = "Trying to set {} points with {} points."
446+
raise ValueError(msg.format(sum(lengths), elements.total_nb_rows))
447+
448+
for o1, l1, o2, l2 in zip(offsets, lengths, elements._offsets, elements._lengths):
449+
data[o1:o1 + l1] = elements._data[o2:o2 + l2]
450+
451+
elif isinstance(elements, numbers.Number):
452+
for o1, l1 in zip(offsets, lengths):
453+
data[o1:o1 + l1] = elements
440454

441-
for o1, l1, o2, l2 in zip(offsets, lengths, elements._offsets, elements._lengths):
442-
data[o1:o1 + l1] = elements._data[o2:o2 + l2]
455+
else: # Try to iterate over it.
456+
for o1, l1, element in zip(offsets, lengths, elements):
457+
data[o1:o1 + l1] = element
443458

444459
def _op(self, op, value=None, inplace=False):
445460
""" Applies some operator to this arraysequence.

nibabel/streamlines/tests/test_array_sequence.py

Lines changed: 50 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -274,6 +274,56 @@ def test_arraysequence_getitem(self):
274274
check_arr_seq_view(seq_view, SEQ_DATA['seq'])
275275
check_arr_seq(seq_view, [d[:, 2] for d in SEQ_DATA['data'][::-2]])
276276

277+
def test_arraysequence_setitem(self):
278+
# Set one item
279+
seq = SEQ_DATA['seq'] * 0
280+
for i, e in enumerate(SEQ_DATA['seq']):
281+
seq[i] = e
282+
283+
check_arr_seq(seq, SEQ_DATA['seq'])
284+
285+
# Setitem with a scalar.
286+
seq = SEQ_DATA['seq'].copy()
287+
seq[:] = 0
288+
assert_true(seq._data.sum() == 0)
289+
290+
# Setitem with a list of ndarray.
291+
seq = SEQ_DATA['seq'] * 0
292+
seq[:] = SEQ_DATA['data']
293+
check_arr_seq(seq, SEQ_DATA['data'])
294+
295+
# Setitem using tuple indexing.
296+
seq = ArraySequence(np.arange(900).reshape((50,6,3)))
297+
seq[:, 0] = 0
298+
assert_true(seq._data[:, 0].sum() == 0)
299+
300+
# Setitem using tuple indexing.
301+
seq = ArraySequence(np.arange(900).reshape((50,6,3)))
302+
seq[range(len(seq))] = 0
303+
assert_true(seq._data.sum() == 0)
304+
305+
# Setitem of a slice using another slice.
306+
seq = ArraySequence(np.arange(900).reshape((50,6,3)))
307+
seq[0:4] = seq[5:9]
308+
check_arr_seq(seq[0:4], seq[5:9])
309+
310+
# Setitem between array sequences with different number of sequences.
311+
seq = ArraySequence(np.arange(900).reshape((50,6,3)))
312+
assert_raises(ValueError, seq.__setitem__, slice(0, 4), seq[5:10])
313+
314+
# Setitem between array sequences with different amount of points.
315+
seq1 = ArraySequence(np.arange(10).reshape(5, 2))
316+
seq2 = ArraySequence(np.arange(15).reshape(5, 3))
317+
assert_raises(ValueError, seq1.__setitem__, slice(0, 5), seq2)
318+
319+
# Setitem between array sequences with different common shape.
320+
seq1 = ArraySequence(np.arange(12).reshape(2, 2, 3))
321+
seq2 = ArraySequence(np.arange(8).reshape(2, 2, 2))
322+
assert_raises(ValueError, seq1.__setitem__, slice(0, 2), seq2)
323+
324+
# Invalid index.
325+
assert_raises(TypeError, seq.__setitem__, object(), None)
326+
277327
def test_arraysequence_operators(self):
278328
# Disable division per zero warnings.
279329
flags = np.seterr(divide='ignore', invalid='ignore')
@@ -375,61 +425,6 @@ def _test_binary(op, arrseq, scalars, seqs, inplace=False):
375425
# Restore flags.
376426
np.seterr(**flags)
377427

378-
379-
def test_arraysequence_setitem(self):
380-
# Set one item
381-
seq = SEQ_DATA['seq'] * 0
382-
for i, e in enumerate(SEQ_DATA['seq']):
383-
seq[i] = e
384-
385-
check_arr_seq(seq, SEQ_DATA['seq'])
386-
387-
# Get all items using indexing (creates a view).
388-
indices = list(range(len(SEQ_DATA['seq'])))
389-
seq_view = SEQ_DATA['seq'][indices]
390-
check_arr_seq_view(seq_view, SEQ_DATA['seq'])
391-
# We took all elements so the view should match the original.
392-
check_arr_seq(seq_view, SEQ_DATA['seq'])
393-
394-
# Get multiple items using ndarray of dtype integer.
395-
for dtype in [np.int8, np.int16, np.int32, np.int64]:
396-
seq_view = SEQ_DATA['seq'][np.array(indices, dtype=dtype)]
397-
check_arr_seq_view(seq_view, SEQ_DATA['seq'])
398-
# We took all elements so the view should match the original.
399-
check_arr_seq(seq_view, SEQ_DATA['seq'])
400-
401-
# Get multiple items out of order (creates a view).
402-
SEQ_DATA['rng'].shuffle(indices)
403-
seq_view = SEQ_DATA['seq'][indices]
404-
check_arr_seq_view(seq_view, SEQ_DATA['seq'])
405-
check_arr_seq(seq_view, [SEQ_DATA['data'][i] for i in indices])
406-
407-
# Get slice (this will create a view).
408-
seq_view = SEQ_DATA['seq'][::2]
409-
check_arr_seq_view(seq_view, SEQ_DATA['seq'])
410-
check_arr_seq(seq_view, SEQ_DATA['data'][::2])
411-
412-
# Use advanced indexing with ndarray of data type bool.
413-
selection = np.array([False, True, True, False, True])
414-
seq_view = SEQ_DATA['seq'][selection]
415-
check_arr_seq_view(seq_view, SEQ_DATA['seq'])
416-
check_arr_seq(seq_view,
417-
[SEQ_DATA['data'][i]
418-
for i, keep in enumerate(selection) if keep])
419-
420-
# Test invalid indexing
421-
assert_raises(TypeError, SEQ_DATA['seq'].__getitem__, 'abc')
422-
423-
# Get specific columns.
424-
seq_view = SEQ_DATA['seq'][:, 2]
425-
check_arr_seq_view(seq_view, SEQ_DATA['seq'])
426-
check_arr_seq(seq_view, [d[:, 2] for d in SEQ_DATA['data']])
427-
428-
# Combining multiple slicing and indexing operations.
429-
seq_view = SEQ_DATA['seq'][::-2][:, 2]
430-
check_arr_seq_view(seq_view, SEQ_DATA['seq'])
431-
check_arr_seq(seq_view, [d[:, 2] for d in SEQ_DATA['data'][::-2]])
432-
433428
def test_arraysequence_repr(self):
434429
# Test that calling repr on a ArraySequence object is not falling.
435430
repr(SEQ_DATA['seq'])

0 commit comments

Comments
 (0)