diff --git a/nibabel/streamlines/array_sequence.py b/nibabel/streamlines/array_sequence.py
index e86cbb5127..71b4bcb3be 100644
--- a/nibabel/streamlines/array_sequence.py
+++ b/nibabel/streamlines/array_sequence.py
@@ -5,6 +5,8 @@
 
 import numpy as np
 
+from ..deprecated import deprecate_with_version
+
 MEGABYTE = 1024 * 1024
 
 
@@ -53,6 +55,37 @@ def update_seq(self, arr_seq):
         arr_seq._lengths = np.array(self.lengths)
 
 
+def _define_operators(cls):
+    """ Decorator which adds support for some Python operators. """
+    def _wrap(cls, op, inplace=False, unary=False):
+
+        def fn_unary_op(self):
+            return self._op(op)
+
+        def fn_binary_op(self, value):
+            return self._op(op, value, inplace=inplace)
+
+        setattr(cls, op, fn_unary_op if unary else fn_binary_op)
+        fn = getattr(cls, op)
+        fn.__name__ = op
+        fn.__doc__ = getattr(np.ndarray, op).__doc__
+
+    for op in ["__add__", "__sub__", "__mul__", "__mod__", "__pow__",
+               "__floordiv__", "__truediv__", "__lshift__", "__rshift__",
+               "__or__", "__and__", "__xor__"]:
+        _wrap(cls, op=op, inplace=False)
+        _wrap(cls, op="__i{}__".format(op.strip("_")), inplace=True)
+
+    for op in ["__eq__", "__ne__", "__lt__", "__le__", "__gt__", "__ge__"]:
+        _wrap(cls, op)
+
+    for op in ["__neg__", "__abs__", "__invert__"]:
+        _wrap(cls, op, unary=True)
+
+    return cls
+
+
+@_define_operators
 class ArraySequence(object):
     """ Sequence of ndarrays having variable first dimension sizes.
 
@@ -116,9 +149,42 @@ def total_nb_rows(self):
         return np.sum(self._lengths)
 
     @property
+    @deprecate_with_version("'ArraySequence.data' property is deprecated.\n"
+                            "Please use the 'ArraySequence.get_data()' method instead",
+                            '3.0', '4.0')
     def data(self):
         """ Elements in this array sequence. """
-        return self._data
+        view = self._data.view()
+        view.setflags(write=False)
+        return view
+
+    def get_data(self):
+        """ Returns a *copy* of the elements in this array sequence.
+
+        Notes
+        -----
+        To modify the data on this array sequence, one can use
+        in-place mathematical operators (e.g., `seq += ...`) or the
+        assignment operator (i.e., `seq[...] = value`).
+        """
+        return self.copy()._data
+
+    def _check_shape(self, arrseq):
+        """ Check whether this array sequence is compatible with another. """
+        msg = "cannot perform operation - array sequences have different"
+        if len(self._lengths) != len(arrseq._lengths):
+            msg += " lengths: {} vs. {}."
+            raise ValueError(msg.format(len(self._lengths), len(arrseq._lengths)))
+
+        if self.total_nb_rows != arrseq.total_nb_rows:
+            msg += " amount of data: {} vs. {}."
+            raise ValueError(msg.format(self.total_nb_rows, arrseq.total_nb_rows))
+
+        if self.common_shape != arrseq.common_shape:
+            msg += " common shape: {} vs. {}."
+            raise ValueError(msg.format(self.common_shape, arrseq.common_shape))
+
+        return True
 
     def _get_next_offset(self):
         """ Offset in ``self._data`` at which to write next rowelement """
@@ -320,7 +386,7 @@ def __getitem__(self, idx):
             seq._lengths = self._lengths[off_idx]
             return seq
 
-        if isinstance(off_idx, list) or is_ndarray_of_int_or_bool(off_idx):
+        if isinstance(off_idx, (list, range)) or is_ndarray_of_int_or_bool(off_idx):
             # Fancy indexing
             seq._offsets = self._offsets[off_idx]
             seq._lengths = self._lengths[off_idx]
@@ -329,6 +395,116 @@
         raise TypeError("Index must be either an int, a slice, a list of int"
                         " or a ndarray of bool! Not " + str(type(idx)))
 
+    def __setitem__(self, idx, elements):
+        """ Set sequence(s) through standard or advanced numpy indexing.
+
+        Parameters
+        ----------
+        idx : int or slice or list or ndarray
+            If int, index of the element to set.
+            If slice, use slicing to set elements.
+            If list, indices of the elements to set.
+            If ndarray with dtype int, indices of the elements to set.
+            If ndarray with dtype bool, only set selected elements.
+        elements : ndarray or :class:`ArraySequence`
+            Data that will overwrite selected sequences.
+            If `idx` is an int, `elements` is expected to be an ndarray.
+            Otherwise, `elements` is expected to be an :class:`ArraySequence` object.
+        """
+        if isinstance(idx, (numbers.Integral, np.integer)):
+            start = self._offsets[idx]
+            self._data[start:start + self._lengths[idx]] = elements
+            return
+
+        if isinstance(idx, tuple):
+            off_idx = idx[0]
+            data = self._data.__getitem__((slice(None),) + idx[1:])
+        else:
+            off_idx = idx
+            data = self._data
+
+        if isinstance(off_idx, slice):  # Standard list slicing
+            offsets = self._offsets[off_idx]
+            lengths = self._lengths[off_idx]
+
+        elif isinstance(off_idx, (list, range)) or is_ndarray_of_int_or_bool(off_idx):
+            # Fancy indexing
+            offsets = self._offsets[off_idx]
+            lengths = self._lengths[off_idx]
+
+        else:
+            raise TypeError("Index must be either an int, a slice, a list of int"
+                            " or a ndarray of bool! Not " + str(type(idx)))
+
+        if is_array_sequence(elements):
+            if len(lengths) != len(elements):
+                msg = "Trying to set {} sequences with {} sequences."
+                raise ValueError(msg.format(len(lengths), len(elements)))
+
+            if sum(lengths) != elements.total_nb_rows:
+                msg = "Trying to set {} points with {} points."
+                raise ValueError(msg.format(sum(lengths), elements.total_nb_rows))
+
+            for o1, l1, o2, l2 in zip(offsets, lengths, elements._offsets, elements._lengths):
+                data[o1:o1 + l1] = elements._data[o2:o2 + l2]
+
+        elif isinstance(elements, numbers.Number):
+            for o1, l1 in zip(offsets, lengths):
+                data[o1:o1 + l1] = elements
+
+        else:  # Try to iterate over it.
+            for o1, l1, element in zip(offsets, lengths, elements):
+                data[o1:o1 + l1] = element
+
+    def _op(self, op, value=None, inplace=False):
+        """ Applies some operator to this array sequence.
+
+        This handles both unary and binary operators with a scalar or another
+        array sequence. Operations are performed directly on the underlying
+        data, or a copy of it, which depends on the value of `inplace`.
+
+        Parameters
+        ----------
+        op : str
+            Name of the Python operator (e.g., `"__add__"`).
+        value : scalar or :class:`ArraySequence`, optional
+            If None, the operator is assumed to be unary.
+            Otherwise, that value is used in the binary operation.
+        inplace : bool, optional
+            If False, the operation is done on a copy of this array sequence.
+            Otherwise, this array sequence gets modified directly.
+        """
+        seq = self if inplace else self.copy()
+
+        if is_array_sequence(value) and seq._check_shape(value):
+            elements = zip(seq._offsets, seq._lengths,
+                           self._offsets, self._lengths,
+                           value._offsets, value._lengths)
+
+            # Change seq.dtype to match the result type of the operation.
+            o0, l0, o1, l1, o2, l2 = next(elements)
+            tmp = getattr(self._data[o1:o1 + l1], op)(value._data[o2:o2 + l2])
+            seq._data = seq._data.astype(tmp.dtype)
+            seq._data[o0:o0 + l0] = tmp
+
+            for o0, l0, o1, l1, o2, l2 in elements:
+                seq._data[o0:o0 + l0] = getattr(self._data[o1:o1 + l1], op)(value._data[o2:o2 + l2])
+
+        else:
+            args = [] if value is None else [value]  # Dealing with unary and binary ops.
+            elements = zip(seq._offsets, seq._lengths, self._offsets, self._lengths)
+
+            # Change seq.dtype to match the result type of the operation.
+            o0, l0, o1, l1 = next(elements)
+            tmp = getattr(self._data[o1:o1 + l1], op)(*args)
+            seq._data = seq._data.astype(tmp.dtype)
+            seq._data[o0:o0 + l0] = tmp
+
+            for o0, l0, o1, l1 in elements:
+                seq._data[o0:o0 + l0] = getattr(self._data[o1:o1 + l1], op)(*args)
+
+        return seq
+
     def __iter__(self):
         if len(self._lengths) != len(self._offsets):
             raise ValueError("ArraySequence object corrupted:"
@@ -371,7 +547,7 @@ def load(cls, filename):
         return seq
 
 
-def create_arraysequences_from_generator(gen, n):
+def create_arraysequences_from_generator(gen, n, buffer_sizes=None):
     """ Creates :class:`ArraySequence` objects from a generator yielding tuples
 
     Parameters
@@ -381,8 +557,13 @@
         array sequences.
     n : int
        Number of :class:`ArraySequences` object to create.
+    buffer_sizes : list of float, optional
+        Sizes (in Mb) for each ArraySequence's buffer.
     """
-    seqs = [ArraySequence() for _ in range(n)]
+    if buffer_sizes is None:
+        buffer_sizes = [4] * n
+
+    seqs = [ArraySequence(buffer_size=size) for size in buffer_sizes]
     for data in gen:
         for i, seq in enumerate(seqs):
             if data[i].nbytes > 0:
diff --git a/nibabel/streamlines/tests/test_array_sequence.py b/nibabel/streamlines/tests/test_array_sequence.py
index 33421f45c7..c92580accb 100644
--- a/nibabel/streamlines/tests/test_array_sequence.py
+++ b/nibabel/streamlines/tests/test_array_sequence.py
@@ -24,7 +24,7 @@ def setup():
 
 
 def generate_data(nb_arrays, common_shape, rng):
-    data = [rng.rand(*(rng.randint(3, 20),) + common_shape)
+    data = [rng.rand(*(rng.randint(3, 20),) + common_shape) * 100
             for _ in range(nb_arrays)]
     return data
 
@@ -228,9 +228,6 @@ def test_arraysequence_getitem(self):
         for i, e in enumerate(SEQ_DATA['seq']):
             assert_array_equal(SEQ_DATA['seq'][i], e)
 
-            if sys.version_info < (3,):
-                assert_array_equal(SEQ_DATA['seq'][long(i)], e)
-
         # Get all items using indexing (creates a view).
         indices = list(range(len(SEQ_DATA['seq'])))
         seq_view = SEQ_DATA['seq'][indices]
@@ -277,6 +274,157 @@ def test_arraysequence_getitem(self):
         check_arr_seq_view(seq_view, SEQ_DATA['seq'])
         check_arr_seq(seq_view, [d[:, 2] for d in SEQ_DATA['data'][::-2]])
 
+    def test_arraysequence_setitem(self):
+        # Set one item
+        seq = SEQ_DATA['seq'] * 0
+        for i, e in enumerate(SEQ_DATA['seq']):
+            seq[i] = e
+
+        check_arr_seq(seq, SEQ_DATA['seq'])
+
+        # Setitem with a scalar.
+        seq = SEQ_DATA['seq'].copy()
+        seq[:] = 0
+        assert_true(seq._data.sum() == 0)
+
+        # Setitem with a list of ndarrays.
+        seq = SEQ_DATA['seq'] * 0
+        seq[:] = SEQ_DATA['data']
+        check_arr_seq(seq, SEQ_DATA['data'])
+
+        # Setitem using tuple indexing.
+        seq = ArraySequence(np.arange(900).reshape((50, 6, 3)))
+        seq[:, 0] = 0
+        assert_true(seq._data[:, 0].sum() == 0)
+
+        # Setitem using a range as index.
+        seq = ArraySequence(np.arange(900).reshape((50, 6, 3)))
+        seq[range(len(seq))] = 0
+        assert_true(seq._data.sum() == 0)
+
+        # Setitem of a slice using another slice.
+        seq = ArraySequence(np.arange(900).reshape((50, 6, 3)))
+        seq[0:4] = seq[5:9]
+        check_arr_seq(seq[0:4], seq[5:9])
+
+        # Setitem between array sequences with different number of sequences.
+        seq = ArraySequence(np.arange(900).reshape((50, 6, 3)))
+        assert_raises(ValueError, seq.__setitem__, slice(0, 4), seq[5:10])
+
+        # Setitem between array sequences with different amounts of points.
+        seq1 = ArraySequence(np.arange(10).reshape(5, 2))
+        seq2 = ArraySequence(np.arange(15).reshape(5, 3))
+        assert_raises(ValueError, seq1.__setitem__, slice(0, 5), seq2)
+
+        # Setitem between array sequences with different common shape.
+        seq1 = ArraySequence(np.arange(12).reshape(2, 2, 3))
+        seq2 = ArraySequence(np.arange(8).reshape(2, 2, 2))
+        assert_raises(ValueError, seq1.__setitem__, slice(0, 2), seq2)
+
+        # Invalid index.
+        assert_raises(TypeError, seq.__setitem__, object(), None)
+
+    def test_arraysequence_operators(self):
+        # Disable divide-by-zero warnings.
+        flags = np.seterr(divide='ignore', invalid='ignore')
+        SCALARS = [42, 0.5, True, -3, 0]
+        CMP_OPS = ["__eq__", "__ne__", "__lt__", "__le__", "__gt__", "__ge__"]
+
+        seq = SEQ_DATA['seq'].copy()
+        seq_int = SEQ_DATA['seq'].copy()
+        seq_int._data = seq_int._data.astype(int)
+        seq_bool = SEQ_DATA['seq'].copy() > 30
+
+        ARRSEQS = [seq, seq_int, seq_bool]
+        VIEWS = [seq[::2], seq_int[::2], seq_bool[::2]]
+
+        def _test_unary(op, arrseq):
+            orig = arrseq.copy()
+            seq = getattr(orig, op)()
+            assert_true(seq is not orig)
+            check_arr_seq(seq, [getattr(d, op)() for d in orig])
+
+        def _test_binary(op, arrseq, scalars, seqs, inplace=False):
+            for scalar in scalars:
+                orig = arrseq.copy()
+                seq = getattr(orig, op)(scalar)
+                assert_true((seq is orig) if inplace else (seq is not orig))
+                check_arr_seq(seq, [getattr(e, op)(scalar) for e in arrseq])
+
+            # Test math operators with another ArraySequence.
+            for other in seqs:
+                orig = arrseq.copy()
+                seq = getattr(orig, op)(other)
+                assert_true(seq is not SEQ_DATA['seq'])
+                check_arr_seq(seq, [getattr(e1, op)(e2) for e1, e2 in zip(arrseq, other)])
+
+            # Operations between array sequences of different lengths.
+            orig = arrseq.copy()
+            assert_raises(ValueError, getattr(orig, op), orig[::2])
+
+            # Operations between array sequences with different amount of data.
+            seq1 = ArraySequence(np.arange(10).reshape(5, 2))
+            seq2 = ArraySequence(np.arange(15).reshape(5, 3))
+            assert_raises(ValueError, getattr(seq1, op), seq2)
+
+            # Operations between array sequences with different common shape.
+            seq1 = ArraySequence(np.arange(12).reshape(2, 2, 3))
+            seq2 = ArraySequence(np.arange(8).reshape(2, 2, 2))
+            assert_raises(ValueError, getattr(seq1, op), seq2)
+
+
+        for op in ["__add__", "__sub__", "__mul__", "__mod__",
+                   "__floordiv__", "__truediv__"] + CMP_OPS:
+            _test_binary(op, seq, SCALARS, ARRSEQS)
+            _test_binary(op, seq_int, SCALARS, ARRSEQS)
+
+            # Test math operators with ArraySequence views.
+            _test_binary(op, seq[::2], SCALARS, VIEWS)
+            _test_binary(op, seq_int[::2], SCALARS, VIEWS)
+
+            if op in CMP_OPS:
+                continue
+
+            op = "__i{}__".format(op.strip("_"))
+            _test_binary(op, seq, SCALARS, ARRSEQS, inplace=True)
+
+            if op == "__itruediv__":
+                continue  # Going to deal with it separately.
+
+            _test_binary(op, seq_int, [42, -3, True, 0], [seq_int, seq_bool, -seq_int], inplace=True)  # int <-- int
+            assert_raises(TypeError, _test_binary, op, seq_int, [0.5], [], inplace=True)  # int <-- float
+            assert_raises(TypeError, _test_binary, op, seq_int, [], [seq], inplace=True)  # int <-- float
+
+        # __pow__ : Integers to negative integer powers are not allowed.
+ _test_binary("__pow__", seq, [42, -3, True, 0], [seq_int, seq_bool, -seq_int]) + _test_binary("__ipow__", seq, [42, -3, True, 0], [seq_int, seq_bool, -seq_int], inplace=True) + assert_raises(ValueError, _test_binary, "__pow__", seq_int, [-3], []) + assert_raises(ValueError, _test_binary, "__ipow__", seq_int, [-3], [], inplace=True) + + # __itruediv__ is only valid with float arrseq. + for scalar in SCALARS + ARRSEQS: + assert_raises(TypeError, getattr(seq_int.copy(), "__itruediv__"), scalar) + + # Bitwise operators + for op in ("__lshift__", "__rshift__", "__or__", "__and__", "__xor__"): + _test_binary(op, seq_bool, [42, -3, True, 0], [seq_int, seq_bool, -seq_int]) + assert_raises(TypeError, _test_binary, op, seq_bool, [0.5], []) + assert_raises(TypeError, _test_binary, op, seq, [], [seq]) + + # Unary operators + for op in ["__neg__", "__abs__"]: + _test_unary(op, seq) + _test_unary(op, -seq) + _test_unary(op, seq_int) + _test_unary(op, -seq_int) + + _test_unary("__abs__", seq_bool) + _test_unary("__invert__", seq_bool) + assert_raises(TypeError, _test_unary, "__invert__", seq) + + # Restore flags. + np.seterr(**flags) + def test_arraysequence_repr(self): # Test that calling repr on a ArraySequence object is not falling. repr(SEQ_DATA['seq']) @@ -319,6 +467,15 @@ def test_save_and_load_arraysequence(self): # Make sure we can add new elements to it. loaded_seq.append(SEQ_DATA['data'][0]) + def test_get_data(self): + seq_view = SEQ_DATA['seq'][::2] + check_arr_seq_view(seq_view, SEQ_DATA['seq']) + + # We make sure the array sequence data does not + # contain more elements than it is supposed to. + data = seq_view.get_data() + assert len(data) < len(seq_view._data) + def test_concatenate(): seq = SEQ_DATA['seq'].copy() # In case there is in-place modification. diff --git a/nibabel/streamlines/tests/test_streamlines.py b/nibabel/streamlines/tests/test_streamlines.py index 2f96e56843..2e537c63f2 100644 --- a/nibabel/streamlines/tests/test_streamlines.py +++ b/nibabel/streamlines/tests/test_streamlines.py @@ -267,6 +267,19 @@ def test_save_complex_file(self): tfile = nib.streamlines.load(filename, lazy_load=False) assert_tractogram_equal(tfile.tractogram, tractogram) + def test_save_sliced_tractogram(self): + tractogram = Tractogram(DATA['streamlines'], + affine_to_rasmm=np.eye(4)) + original_tractogram = tractogram.copy() + for ext, cls in FORMATS.items(): + with InTemporaryDirectory(): + filename = 'streamlines' + ext + nib.streamlines.save(tractogram[::2], filename) + tfile = nib.streamlines.load(filename, lazy_load=False) + assert_tractogram_equal(tfile.tractogram, tractogram[::2]) + # Make sure original tractogram hasn't changed. + assert_tractogram_equal(tractogram, original_tractogram) + def test_load_unknown_format(self): assert_raises(ValueError, nib.streamlines.load, "") diff --git a/nibabel/streamlines/tests/test_tractogram.py b/nibabel/streamlines/tests/test_tractogram.py index 888de0bd49..407f3ef413 100644 --- a/nibabel/streamlines/tests/test_tractogram.py +++ b/nibabel/streamlines/tests/test_tractogram.py @@ -539,9 +539,6 @@ def test_tractogram_getitem(self): for i, t in enumerate(DATA['tractogram']): assert_tractogram_item_equal(DATA['tractogram'][i], t) - if sys.version_info < (3,): - assert_tractogram_item_equal(DATA['tractogram'][long(i)], t) - # Get one TractogramItem out of two. 
         tractogram_view = DATA['simple_tractogram'][::2]
         check_tractogram(tractogram_view, DATA['streamlines'][::2])
@@ -688,6 +685,22 @@ def test_tractogram_apply_affine(self):
                                   np.dot(np.eye(4), np.dot(np.linalg.inv(affine),
                                                            np.linalg.inv(affine))))
 
+        # Applying the affine to a tractogram that has been indexed or sliced
+        # shouldn't affect the remaining streamlines.
+        tractogram = DATA['tractogram'].copy()
+        transformed_tractogram = tractogram[::2].apply_affine(affine)
+        assert_true(transformed_tractogram is not tractogram)
+        check_tractogram(tractogram[::2],
+                         streamlines=[s * scaling for s in DATA['streamlines'][::2]],
+                         data_per_streamline=DATA['tractogram'].data_per_streamline[::2],
+                         data_per_point=DATA['tractogram'].data_per_point[::2])
+
+        # Remaining streamlines should match the original ones.
+        check_tractogram(tractogram[1::2],
+                         streamlines=DATA['streamlines'][1::2],
+                         data_per_streamline=DATA['tractogram'].data_per_streamline[1::2],
+                         data_per_point=DATA['tractogram'].data_per_point[1::2])
+
         # Check that applying an affine and its inverse give us back the
         # original streamlines.
         tractogram = DATA['tractogram'].copy()
diff --git a/nibabel/streamlines/tractogram.py b/nibabel/streamlines/tractogram.py
index c6687b82aa..3d01d8426e 100644
--- a/nibabel/streamlines/tractogram.py
+++ b/nibabel/streamlines/tractogram.py
@@ -432,11 +432,8 @@ def apply_affine(self, affine, lazy=False):
         if np.all(affine == np.eye(4)):
             return self  # No transformation.
 
-        BUFFER_SIZE = 10000000  # About 128 Mb since pts shape is 3.
-        for start in range(0, len(self.streamlines.data), BUFFER_SIZE):
-            end = start + BUFFER_SIZE
-            pts = self.streamlines._data[start:end]
-            self.streamlines.data[start:end] = apply_affine(affine, pts)
+        for i in range(len(self.streamlines)):
+            self.streamlines[i] = apply_affine(affine, self.streamlines[i])
 
         if self.affine_to_rasmm is not None:
             # Update the affine that brings back the streamlines to RASmm.
diff --git a/nibabel/streamlines/trk.py b/nibabel/streamlines/trk.py
index f67ab1509a..2397a3ff24 100644
--- a/nibabel/streamlines/trk.py
+++ b/nibabel/streamlines/trk.py
@@ -372,8 +372,23 @@ def _read():
             tractogram = LazyTractogram.from_data_func(_read)
 
         else:
+
+            # Speed up loading by guessing a suitable buffer size.
+            with Opener(fileobj) as f:
+                old_file_position = f.tell()
+                f.seek(0, os.SEEK_END)
+                size = f.tell()
+                f.seek(old_file_position, os.SEEK_SET)
+
+            # Buffer size is in megabytes.
+            mbytes = size // (1024 * 1024)
+            sizes = [mbytes, 4, 4]
+            if hdr["nb_scalars_per_point"] > 0:
+                sizes = [mbytes // 2, mbytes // 2, 4]
+
             trk_reader = cls._read(fileobj, hdr)
-            arr_seqs = create_arraysequences_from_generator(trk_reader, n=3)
+            arr_seqs = create_arraysequences_from_generator(trk_reader, n=3,
+                                                            buffer_sizes=sizes)
             streamlines, scalars, properties = arr_seqs
             properties = np.asarray(properties)  # Actually a 2d array.
             tractogram = Tractogram(streamlines)
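
The following usage sketch is illustrative only and not part of the patch; it shows the behaviour the array_sequence.py changes enable, assuming the usual `nibabel.streamlines` import path and arbitrary array values:

    import numpy as np
    from nibabel.streamlines import ArraySequence

    # A small sequence of 2D point arrays with different first dimensions.
    seq = ArraySequence([np.arange(6, dtype=float).reshape(3, 2),
                         np.arange(8, dtype=float).reshape(4, 2)])

    # Operators added by the @_define_operators decorator work element-wise
    # against scalars and other sequences; the plain forms return a new
    # ArraySequence, the in-place forms (e.g. `+=`) modify the data directly.
    doubled = seq * 2
    seq += 1
    mask = seq > 4          # boolean ArraySequence

    # get_data() returns a copy of the underlying 2D data block, while the
    # deprecated `.data` property now returns a read-only view.
    points = seq.get_data()

    # __setitem__ writes back through standard or fancy indexing.
    seq[0] = np.zeros((3, 2))
    seq[:] = 0

The `buffer_sizes` hook on `create_arraysequences_from_generator` is exercised the same way `TrkFile.load` does above, e.g. `create_arraysequences_from_generator(gen, n=3, buffer_sizes=[64, 4, 4])` with `gen` yielding size-3 tuples of arrays.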