Skip to content

Commit 1642ba0

Browse files
committed
NF: add support for Python operators to ArraySequence
1 parent 3b3131a commit 1642ba0

File tree

2 files changed

+138
-0
lines changed

2 files changed

+138
-0
lines changed

nibabel/streamlines/array_sequence.py

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,35 @@ def update_seq(self, arr_seq):
5353
arr_seq._lengths = np.array(self.lengths)
5454

5555

56+
def _define_operators(cls):
57+
""" Decorator which adds support for some Python operators. """
58+
def _wrap(cls, op, name=None, inplace=False, unary=False):
59+
name = name or op
60+
if unary:
61+
setattr(cls, name, lambda self: self._op(op))
62+
else:
63+
setattr(cls, name,
64+
lambda self, value: self._op(op, value, inplace=inplace))
65+
66+
for op in ["__iadd__", "__isub__", "__imul__", "__idiv__",
67+
"__ifloordiv__", "__itruediv__", "__ior__"]:
68+
_wrap(cls, op, inplace=True)
69+
70+
for op in ["__add__", "__sub__", "__mul__", "__div__",
71+
"__floordiv__", "__truediv__", "__or__"]:
72+
op_ = "__i{}__".format(op.strip("_"))
73+
_wrap(cls, op_, name=op)
74+
75+
for op in ["__eq__", "__ne__", "__lt__", "__le__", "__gt__", "__ge__"]:
76+
_wrap(cls, op)
77+
78+
for op in ["__neg__"]:
79+
_wrap(cls, op, unary=True)
80+
81+
return cls
82+
83+
84+
@_define_operators
5685
class ArraySequence(object):
5786
""" Sequence of ndarrays having variable first dimension sizes.
5887
@@ -120,6 +149,23 @@ def data(self):
120149
""" Elements in this array sequence. """
121150
return self._data
122151

152+
def _check_shape(self, arrseq):
153+
""" Check whether this array sequence is compatible with another. """
154+
msg = "cannot perform operation - array sequences have different"
155+
if len(self._lengths) != len(arrseq._lengths):
156+
msg += " lengths: {} vs. {}."
157+
raise ValueError(msg.format(len(self._lengths), len(arrseq._lengths)))
158+
159+
if self.total_nb_rows != arrseq.total_nb_rows:
160+
msg += " amount of data: {} vs. {}."
161+
raise ValueError(msg.format(self.total_nb_rows, arrseq.total_nb_rows))
162+
163+
if self.common_shape != arrseq.common_shape:
164+
msg += " common shape: {} vs. {}."
165+
raise ValueError(msg.format(self.common_shape, arrseq.common_shape))
166+
167+
return True
168+
123169
def _get_next_offset(self):
124170
""" Offset in ``self._data`` at which to write next rowelement """
125171
if len(self._offsets) == 0:
@@ -377,6 +423,37 @@ def __setitem__(self, idx, elements):
377423
for o1, l1, o2, l2 in zip(offsets, lengths, elements._offsets, elements._lengths):
378424
data[o1:o1 + l1] = elements._data[o2:o2 + l2]
379425

426+
def _op(self, op, value=None, inplace=False):
427+
""" Applies some operator to this arraysequence.
428+
429+
This handles both unary and binary operators with a scalar or another
430+
array sequence. Operations are performed directly on the underlying
431+
data, or a copy of it, which depends on the value of `inplace`.
432+
433+
Parameters
434+
----------
435+
op : str
436+
Name of the Python operator (e.g., `"__add__"`).
437+
value : scalar or :class:`ArraySequence`, optional
438+
If None, the operator is assumed to be unary.
439+
Otherwise, that value is used in the binary operation.
440+
inplace: bool, optional
441+
If False, the operation is done on a copy of this array sequence.
442+
Otherwise, this array sequence gets modified directly.
443+
"""
444+
seq = self if inplace else self.copy()
445+
446+
if is_array_sequence(value) and seq._check_shape(value):
447+
for o1, l1, o2, l2 in zip(seq._offsets, seq._lengths, value._offsets, value._lengths):
448+
seq._data[o1:o1 + l1] = getattr(seq._data[o1:o1 + l1], op)(value._data[o2:o2 + l2])
449+
450+
else:
451+
args = [] if value is None else [value] # Dealing with unary and binary ops.
452+
for o1, l1 in zip(seq._offsets, seq._lengths):
453+
seq._data[o1:o1 + l1] = getattr(seq._data[o1:o1 + l1], op)(*args)
454+
455+
return seq
456+
380457
def __iter__(self):
381458
if len(self._lengths) != len(self._offsets):
382459
raise ValueError("ArraySequence object corrupted:"

nibabel/streamlines/tests/test_array_sequence.py

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -277,6 +277,67 @@ def test_arraysequence_getitem(self):
277277
check_arr_seq_view(seq_view, SEQ_DATA['seq'])
278278
check_arr_seq(seq_view, [d[:, 2] for d in SEQ_DATA['data'][::-2]])
279279

280+
def test_arraysequence_operators(self):
281+
for op in ["__add__", "__sub__", "__mul__", "__floordiv__", "__truediv__",
282+
"__eq__", "__ne__", "__lt__", "__le__", "__gt__", "__ge__"]:
283+
# Test math operators with a scalar.
284+
for scalar in [42, 0.5, True]:
285+
seq = getattr(SEQ_DATA['seq'], op)(scalar)
286+
assert_true(seq is not SEQ_DATA['seq'])
287+
check_arr_seq(seq, [getattr(d, op)(scalar) for d in SEQ_DATA['data']])
288+
289+
# Test math operators with another ArraySequence.
290+
seq = getattr(SEQ_DATA['seq'], op)(SEQ_DATA['seq'])
291+
assert_true(seq is not SEQ_DATA['seq'])
292+
check_arr_seq(seq, [getattr(d, op)(d) for d in SEQ_DATA['data']])
293+
294+
# Test math operators with ArraySequence views.
295+
orig = SEQ_DATA['seq'][::2]
296+
seq = getattr(orig, op)(orig)
297+
assert_true(seq is not orig)
298+
check_arr_seq(seq, [getattr(d, op)(d) for d in SEQ_DATA['data'][::2]])
299+
300+
# Test in-place operators.
301+
for op in ["__iadd__", "__isub__", "__imul__", "__ifloordiv__", "__itruediv__"]:
302+
# Test in-place math operators with a scalar.
303+
for scalar in [42, 0.5, True]:
304+
seq = seq_orig = SEQ_DATA['seq'].copy()
305+
seq = getattr(seq, op)(scalar)
306+
assert_true(seq is seq_orig)
307+
check_arr_seq(seq, [getattr(d.copy(), op)(scalar) for d in SEQ_DATA['data']])
308+
309+
# Test in-place math operators with another ArraySequence.
310+
seq = seq_orig = SEQ_DATA['seq'].copy()
311+
seq = getattr(seq, op)(SEQ_DATA['seq'])
312+
assert_true(seq is seq_orig)
313+
check_arr_seq(seq, [getattr(d.copy(), op)(d) for d in SEQ_DATA['data']])
314+
315+
# Test in-place math operators with ArraySequence views.
316+
seq = seq_orig = SEQ_DATA['seq'].copy()[::2]
317+
seq = getattr(seq, op)(seq)
318+
assert_true(seq is seq_orig)
319+
check_arr_seq(seq, [getattr(d.copy(), op)(d) for d in SEQ_DATA['data'][::2]])
320+
321+
# Operations between array sequences of different lengths.
322+
seq = SEQ_DATA['seq'].copy()
323+
assert_raises(ValueError, getattr(seq, op), SEQ_DATA['seq'][::2])
324+
325+
# Operations between array sequences with different amount of data.
326+
seq1 = ArraySequence(np.arange(10).reshape(5, 2))
327+
seq2 = ArraySequence(np.arange(15).reshape(5, 3))
328+
assert_raises(ValueError, getattr(seq1, op), seq2)
329+
330+
# Operations between array sequences with different common shape.
331+
seq1 = ArraySequence(np.arange(12).reshape(2, 2, 3))
332+
seq2 = ArraySequence(np.arange(8).reshape(2, 2, 2))
333+
assert_raises(ValueError, getattr(seq1, op), seq2)
334+
335+
# Unary operators
336+
for op in ["__neg__"]:
337+
seq = getattr(SEQ_DATA['seq'], op)()
338+
assert_true(seq is not SEQ_DATA['seq'])
339+
check_arr_seq(seq, [getattr(d, op)() for d in SEQ_DATA['data']])
340+
280341
def test_arraysequence_setitem(self):
281342
# Set one item
282343
seq = SEQ_DATA['seq'] * 0

0 commit comments

Comments
 (0)