Skip to content

Commit 6cf363d

Browse files
authored
Merge pull request #811 from MarcCote/ref_arrseq_data
Fixing ArraySequence functionalities
2 parents 437073c + d21a980 commit 6cf363d

File tree

6 files changed

+393
-17
lines changed

6 files changed

+393
-17
lines changed

nibabel/streamlines/array_sequence.py

Lines changed: 185 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55

66
import numpy as np
77

8+
from ..deprecated import deprecate_with_version
9+
810
MEGABYTE = 1024 * 1024
911

1012

@@ -53,6 +55,37 @@ def update_seq(self, arr_seq):
5355
arr_seq._lengths = np.array(self.lengths)
5456

5557

58+
def _define_operators(cls):
59+
""" Decorator which adds support for some Python operators. """
60+
def _wrap(cls, op, inplace=False, unary=False):
61+
62+
def fn_unary_op(self):
63+
return self._op(op)
64+
65+
def fn_binary_op(self, value):
66+
return self._op(op, value, inplace=inplace)
67+
68+
setattr(cls, op, fn_unary_op if unary else fn_binary_op)
69+
fn = getattr(cls, op)
70+
fn.__name__ = op
71+
fn.__doc__ = getattr(np.ndarray, op).__doc__
72+
73+
for op in ["__add__", "__sub__", "__mul__", "__mod__", "__pow__",
74+
"__floordiv__", "__truediv__", "__lshift__", "__rshift__",
75+
"__or__", "__and__", "__xor__"]:
76+
_wrap(cls, op=op, inplace=False)
77+
_wrap(cls, op="__i{}__".format(op.strip("_")), inplace=True)
78+
79+
for op in ["__eq__", "__ne__", "__lt__", "__le__", "__gt__", "__ge__"]:
80+
_wrap(cls, op)
81+
82+
for op in ["__neg__", "__abs__", "__invert__"]:
83+
_wrap(cls, op, unary=True)
84+
85+
return cls
86+
87+
88+
@_define_operators
5689
class ArraySequence(object):
5790
""" Sequence of ndarrays having variable first dimension sizes.
5891
@@ -116,9 +149,42 @@ def total_nb_rows(self):
116149
return np.sum(self._lengths)
117150

118151
@property
152+
@deprecate_with_version("'ArraySequence.data' property is deprecated.\n"
153+
"Please use the 'ArraySequence.get_data()' method instead",
154+
'3.0', '4.0')
119155
def data(self):
120156
""" Elements in this array sequence. """
121-
return self._data
157+
view = self._data.view()
158+
view.setflags(write=False)
159+
return view
160+
161+
def get_data(self):
162+
""" Returns a *copy* of the elements in this array sequence.
163+
164+
Notes
165+
-----
166+
To modify the data on this array sequence, one can use
167+
in-place mathematical operators (e.g., `seq += ...`) or the use
168+
assignment operator (i.e, `seq[...] = value`).
169+
"""
170+
return self.copy()._data
171+
172+
def _check_shape(self, arrseq):
173+
""" Check whether this array sequence is compatible with another. """
174+
msg = "cannot perform operation - array sequences have different"
175+
if len(self._lengths) != len(arrseq._lengths):
176+
msg += " lengths: {} vs. {}."
177+
raise ValueError(msg.format(len(self._lengths), len(arrseq._lengths)))
178+
179+
if self.total_nb_rows != arrseq.total_nb_rows:
180+
msg += " amount of data: {} vs. {}."
181+
raise ValueError(msg.format(self.total_nb_rows, arrseq.total_nb_rows))
182+
183+
if self.common_shape != arrseq.common_shape:
184+
msg += " common shape: {} vs. {}."
185+
raise ValueError(msg.format(self.common_shape, arrseq.common_shape))
186+
187+
return True
122188

123189
def _get_next_offset(self):
124190
""" Offset in ``self._data`` at which to write next rowelement """
@@ -320,7 +386,7 @@ def __getitem__(self, idx):
320386
seq._lengths = self._lengths[off_idx]
321387
return seq
322388

323-
if isinstance(off_idx, list) or is_ndarray_of_int_or_bool(off_idx):
389+
if isinstance(off_idx, (list, range)) or is_ndarray_of_int_or_bool(off_idx):
324390
# Fancy indexing
325391
seq._offsets = self._offsets[off_idx]
326392
seq._lengths = self._lengths[off_idx]
@@ -329,6 +395,116 @@ def __getitem__(self, idx):
329395
raise TypeError("Index must be either an int, a slice, a list of int"
330396
" or a ndarray of bool! Not " + str(type(idx)))
331397

398+
def __setitem__(self, idx, elements):
399+
""" Set sequence(s) through standard or advanced numpy indexing.
400+
401+
Parameters
402+
----------
403+
idx : int or slice or list or ndarray
404+
If int, index of the element to retrieve.
405+
If slice, use slicing to retrieve elements.
406+
If list, indices of the elements to retrieve.
407+
If ndarray with dtype int, indices of the elements to retrieve.
408+
If ndarray with dtype bool, only retrieve selected elements.
409+
elements: ndarray or :class:`ArraySequence`
410+
Data that will overwrite selected sequences.
411+
If `idx` is an int, `elements` is expected to be a ndarray.
412+
Otherwise, `elements` is expected a :class:`ArraySequence` object.
413+
"""
414+
if isinstance(idx, (numbers.Integral, np.integer)):
415+
start = self._offsets[idx]
416+
self._data[start:start + self._lengths[idx]] = elements
417+
return
418+
419+
if isinstance(idx, tuple):
420+
off_idx = idx[0]
421+
data = self._data.__getitem__((slice(None),) + idx[1:])
422+
else:
423+
off_idx = idx
424+
data = self._data
425+
426+
if isinstance(off_idx, slice): # Standard list slicing
427+
offsets = self._offsets[off_idx]
428+
lengths = self._lengths[off_idx]
429+
430+
elif isinstance(off_idx, (list, range)) or is_ndarray_of_int_or_bool(off_idx):
431+
# Fancy indexing
432+
offsets = self._offsets[off_idx]
433+
lengths = self._lengths[off_idx]
434+
435+
else:
436+
raise TypeError("Index must be either an int, a slice, a list of int"
437+
" or a ndarray of bool! Not " + str(type(idx)))
438+
439+
if is_array_sequence(elements):
440+
if len(lengths) != len(elements):
441+
msg = "Trying to set {} sequences with {} sequences."
442+
raise ValueError(msg.format(len(lengths), len(elements)))
443+
444+
if sum(lengths) != elements.total_nb_rows:
445+
msg = "Trying to set {} points with {} points."
446+
raise ValueError(msg.format(sum(lengths), elements.total_nb_rows))
447+
448+
for o1, l1, o2, l2 in zip(offsets, lengths, elements._offsets, elements._lengths):
449+
data[o1:o1 + l1] = elements._data[o2:o2 + l2]
450+
451+
elif isinstance(elements, numbers.Number):
452+
for o1, l1 in zip(offsets, lengths):
453+
data[o1:o1 + l1] = elements
454+
455+
else: # Try to iterate over it.
456+
for o1, l1, element in zip(offsets, lengths, elements):
457+
data[o1:o1 + l1] = element
458+
459+
def _op(self, op, value=None, inplace=False):
460+
""" Applies some operator to this arraysequence.
461+
462+
This handles both unary and binary operators with a scalar or another
463+
array sequence. Operations are performed directly on the underlying
464+
data, or a copy of it, which depends on the value of `inplace`.
465+
466+
Parameters
467+
----------
468+
op : str
469+
Name of the Python operator (e.g., `"__add__"`).
470+
value : scalar or :class:`ArraySequence`, optional
471+
If None, the operator is assumed to be unary.
472+
Otherwise, that value is used in the binary operation.
473+
inplace: bool, optional
474+
If False, the operation is done on a copy of this array sequence.
475+
Otherwise, this array sequence gets modified directly.
476+
"""
477+
seq = self if inplace else self.copy()
478+
479+
if is_array_sequence(value) and seq._check_shape(value):
480+
elements = zip(seq._offsets, seq._lengths,
481+
self._offsets, self._lengths,
482+
value._offsets, value._lengths)
483+
484+
# Change seq.dtype to match the operation resulting type.
485+
o0, l0, o1, l1, o2, l2 = next(elements)
486+
tmp = getattr(self._data[o1:o1 + l1], op)(value._data[o2:o2 + l2])
487+
seq._data = seq._data.astype(tmp.dtype)
488+
seq._data[o0:o0 + l0] = tmp
489+
490+
for o0, l0, o1, l1, o2, l2 in elements:
491+
seq._data[o0:o0 + l0] = getattr(self._data[o1:o1 + l1], op)(value._data[o2:o2 + l2])
492+
493+
else:
494+
args = [] if value is None else [value] # Dealing with unary and binary ops.
495+
elements = zip(seq._offsets, seq._lengths, self._offsets, self._lengths)
496+
497+
# Change seq.dtype to match the operation resulting type.
498+
o0, l0, o1, l1 = next(elements)
499+
tmp = getattr(self._data[o1:o1 + l1], op)(*args)
500+
seq._data = seq._data.astype(tmp.dtype)
501+
seq._data[o0:o0 + l0] = tmp
502+
503+
for o0, l0, o1, l1 in elements:
504+
seq._data[o0:o0 + l0] = getattr(self._data[o1:o1 + l1], op)(*args)
505+
506+
return seq
507+
332508
def __iter__(self):
333509
if len(self._lengths) != len(self._offsets):
334510
raise ValueError("ArraySequence object corrupted:"
@@ -371,7 +547,7 @@ def load(cls, filename):
371547
return seq
372548

373549

374-
def create_arraysequences_from_generator(gen, n):
550+
def create_arraysequences_from_generator(gen, n, buffer_sizes=None):
375551
""" Creates :class:`ArraySequence` objects from a generator yielding tuples
376552
377553
Parameters
@@ -381,8 +557,13 @@ def create_arraysequences_from_generator(gen, n):
381557
array sequences.
382558
n : int
383559
Number of :class:`ArraySequences` object to create.
560+
buffer_sizes : list of float, optional
561+
Sizes (in Mb) for each ArraySequence's buffer.
384562
"""
385-
seqs = [ArraySequence() for _ in range(n)]
563+
if buffer_sizes is None:
564+
buffer_sizes = [4] * n
565+
566+
seqs = [ArraySequence(buffer_size=size) for size in buffer_sizes]
386567
for data in gen:
387568
for i, seq in enumerate(seqs):
388569
if data[i].nbytes > 0:

0 commit comments

Comments
 (0)