Skip to content

Commit 7400811

Browse files
committed
Set arrseq dtype according to the arithmetic operators.
Remove Python version checks. Refactor unit tests for arithmetic operators.
1 parent 5345631 commit 7400811

File tree

3 files changed

+123
-75
lines changed

3 files changed

+123
-75
lines changed

nibabel/streamlines/array_sequence.py

Lines changed: 39 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -56,27 +56,29 @@ def update_seq(self, arr_seq):
5656

5757
def _define_operators(cls):
5858
""" Decorator which adds support for some Python operators. """
59-
def _wrap(cls, op, name=None, inplace=False, unary=False):
60-
name = name or op
61-
if unary:
62-
setattr(cls, name, lambda self: self._op(op))
63-
else:
64-
setattr(cls, name,
65-
lambda self, value: self._op(op, value, inplace=inplace))
59+
def _wrap(cls, op, inplace=False, unary=False):
60+
61+
def fn_unary_op(self):
62+
return self._op(op)
6663

67-
for op in ["__iadd__", "__isub__", "__imul__", "__idiv__",
68-
"__ifloordiv__", "__itruediv__", "__ior__"]:
69-
_wrap(cls, op, inplace=True)
64+
def fn_binary_op(self, value):
65+
return self._op(op, value, inplace=inplace)
7066

71-
for op in ["__add__", "__sub__", "__mul__", "__div__",
72-
"__floordiv__", "__truediv__", "__or__"]:
73-
op_ = "__i{}__".format(op.strip("_"))
74-
_wrap(cls, op_, name=op)
67+
setattr(cls, op, fn_unary_op if unary else fn_binary_op)
68+
fn = getattr(cls, op)
69+
fn.__name__ = op
70+
fn.__doc__ = getattr(np.ndarray, op).__doc__
71+
72+
for op in ["__add__", "__sub__", "__mul__", "__mod__", "__pow__",
73+
"__floordiv__", "__truediv__", "__lshift__", "__rshift__",
74+
"__or__", "__and__", "__xor__"]:
75+
_wrap(cls, op=op, inplace=False)
76+
_wrap(cls, op="__i{}__".format(op.strip("_")), inplace=True)
7577

7678
for op in ["__eq__", "__ne__", "__lt__", "__le__", "__gt__", "__ge__"]:
7779
_wrap(cls, op)
7880

79-
for op in ["__neg__"]:
81+
for op in ["__neg__", "__abs__", "__invert__"]:
8082
_wrap(cls, op, unary=True)
8183

8284
return cls
@@ -460,13 +462,31 @@ def _op(self, op, value=None, inplace=False):
460462
seq = self if inplace else self.copy()
461463

462464
if is_array_sequence(value) and seq._check_shape(value):
463-
for o1, l1, o2, l2 in zip(seq._offsets, seq._lengths, value._offsets, value._lengths):
464-
seq._data[o1:o1 + l1] = getattr(seq._data[o1:o1 + l1], op)(value._data[o2:o2 + l2])
465+
elements = zip(seq._offsets, seq._lengths,
466+
self._offsets, self._lengths,
467+
value._offsets, value._lengths)
468+
469+
# Change seq.dtype to match the operation resulting type.
470+
o0, l0, o1, l1, o2, l2 = next(elements)
471+
tmp = getattr(self._data[o1:o1 + l1], op)(value._data[o2:o2 + l2])
472+
seq._data = seq._data.astype(tmp.dtype)
473+
seq._data[o0:o0 + l0] = tmp
474+
475+
for o0, l0, o1, l1, o2, l2 in elements:
476+
seq._data[o0:o0 + l0] = getattr(self._data[o1:o1 + l1], op)(value._data[o2:o2 + l2])
465477

466478
else:
467479
args = [] if value is None else [value] # Dealing with unary and binary ops.
468-
for o1, l1 in zip(seq._offsets, seq._lengths):
469-
seq._data[o1:o1 + l1] = getattr(seq._data[o1:o1 + l1], op)(*args)
480+
elements = zip(seq._offsets, seq._lengths, self._offsets, self._lengths)
481+
482+
# Change seq.dtype to match the operation resulting type.
483+
o0, l0, o1, l1 = next(elements)
484+
tmp = getattr(self._data[o1:o1 + l1], op)(*args)
485+
seq._data = seq._data.astype(tmp.dtype)
486+
seq._data[o0:o0 + l0] = tmp
487+
488+
for o0, l0, o1, l1 in elements:
489+
seq._data[o0:o0 + l0] = getattr(self._data[o1:o1 + l1], op)(*args)
470490

471491
return seq
472492

nibabel/streamlines/tests/test_array_sequence.py

Lines changed: 84 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ def setup():
2424

2525

2626
def generate_data(nb_arrays, common_shape, rng):
27-
data = [rng.rand(*(rng.randint(3, 20),) + common_shape)
27+
data = [rng.rand(*(rng.randint(3, 20),) + common_shape) * 100
2828
for _ in range(nb_arrays)]
2929
return data
3030

@@ -228,9 +228,6 @@ def test_arraysequence_getitem(self):
228228
for i, e in enumerate(SEQ_DATA['seq']):
229229
assert_array_equal(SEQ_DATA['seq'][i], e)
230230

231-
if sys.version_info < (3,):
232-
assert_array_equal(SEQ_DATA['seq'][long(i)], e)
233-
234231
# Get all items using indexing (creates a view).
235232
indices = list(range(len(SEQ_DATA['seq'])))
236233
seq_view = SEQ_DATA['seq'][indices]
@@ -278,49 +275,42 @@ def test_arraysequence_getitem(self):
278275
check_arr_seq(seq_view, [d[:, 2] for d in SEQ_DATA['data'][::-2]])
279276

280277
def test_arraysequence_operators(self):
281-
for op in ["__add__", "__sub__", "__mul__", "__floordiv__", "__truediv__",
282-
"__eq__", "__ne__", "__lt__", "__le__", "__gt__", "__ge__"]:
283-
# Test math operators with a scalar.
284-
for scalar in [42, 0.5, True]:
285-
seq = getattr(SEQ_DATA['seq'], op)(scalar)
286-
assert_true(seq is not SEQ_DATA['seq'])
287-
check_arr_seq(seq, [getattr(d, op)(scalar) for d in SEQ_DATA['data']])
278+
# Disable division per zero warnings.
279+
flags = np.seterr(divide='ignore', invalid='ignore')
280+
SCALARS = [42, 0.5, True, -3, 0]
281+
CMP_OPS = ["__eq__", "__ne__", "__lt__", "__le__", "__gt__", "__ge__"]
288282

289-
# Test math operators with another ArraySequence.
290-
seq = getattr(SEQ_DATA['seq'], op)(SEQ_DATA['seq'])
291-
assert_true(seq is not SEQ_DATA['seq'])
292-
check_arr_seq(seq, [getattr(d, op)(d) for d in SEQ_DATA['data']])
283+
seq = SEQ_DATA['seq'].copy()
284+
seq_int = SEQ_DATA['seq'].copy()
285+
seq_int._data = seq_int._data.astype(int)
286+
seq_bool = SEQ_DATA['seq'].copy() > 30
293287

294-
# Test math operators with ArraySequence views.
295-
orig = SEQ_DATA['seq'][::2]
296-
seq = getattr(orig, op)(orig)
288+
ARRSEQS = [seq, seq_int, seq_bool]
289+
VIEWS = [seq[::2], seq_int[::2], seq_bool[::2]]
290+
291+
def _test_unary(op, arrseq):
292+
orig = arrseq.copy()
293+
seq = getattr(orig, op)()
297294
assert_true(seq is not orig)
298-
check_arr_seq(seq, [getattr(d, op)(d) for d in SEQ_DATA['data'][::2]])
299-
300-
# Test in-place operators.
301-
for op in ["__iadd__", "__isub__", "__imul__", "__ifloordiv__", "__itruediv__"]:
302-
# Test in-place math operators with a scalar.
303-
for scalar in [42, 0.5, True]:
304-
seq = seq_orig = SEQ_DATA['seq'].copy()
305-
seq = getattr(seq, op)(scalar)
306-
assert_true(seq is seq_orig)
307-
check_arr_seq(seq, [getattr(d.copy(), op)(scalar) for d in SEQ_DATA['data']])
308-
309-
# Test in-place math operators with another ArraySequence.
310-
seq = seq_orig = SEQ_DATA['seq'].copy()
311-
seq = getattr(seq, op)(SEQ_DATA['seq'])
312-
assert_true(seq is seq_orig)
313-
check_arr_seq(seq, [getattr(d.copy(), op)(d) for d in SEQ_DATA['data']])
314-
315-
# Test in-place math operators with ArraySequence views.
316-
seq = seq_orig = SEQ_DATA['seq'].copy()[::2]
317-
seq = getattr(seq, op)(seq)
318-
assert_true(seq is seq_orig)
319-
check_arr_seq(seq, [getattr(d.copy(), op)(d) for d in SEQ_DATA['data'][::2]])
295+
check_arr_seq(seq, [getattr(d, op)() for d in orig])
296+
297+
def _test_binary(op, arrseq, scalars, seqs, inplace=False):
298+
for scalar in scalars:
299+
orig = arrseq.copy()
300+
seq = getattr(orig, op)(scalar)
301+
assert_true((seq is orig) if inplace else (seq is not orig))
302+
check_arr_seq(seq, [getattr(e, op)(scalar) for e in arrseq])
303+
304+
# Test math operators with another ArraySequence.
305+
for other in seqs:
306+
orig = arrseq.copy()
307+
seq = getattr(orig, op)(other)
308+
assert_true(seq is not SEQ_DATA['seq'])
309+
check_arr_seq(seq, [getattr(e1, op)(e2) for e1, e2 in zip(arrseq, other)])
320310

321311
# Operations between array sequences of different lengths.
322-
seq = SEQ_DATA['seq'].copy()
323-
assert_raises(ValueError, getattr(seq, op), SEQ_DATA['seq'][::2])
312+
orig = arrseq.copy()
313+
assert_raises(ValueError, getattr(orig, op), orig[::2])
324314

325315
# Operations between array sequences with different amount of data.
326316
seq1 = ArraySequence(np.arange(10).reshape(5, 2))
@@ -332,11 +322,59 @@ def test_arraysequence_operators(self):
332322
seq2 = ArraySequence(np.arange(8).reshape(2, 2, 2))
333323
assert_raises(ValueError, getattr(seq1, op), seq2)
334324

325+
326+
for op in ["__add__", "__sub__", "__mul__", "__mod__",
327+
"__floordiv__", "__truediv__"] + CMP_OPS:
328+
_test_binary(op, seq, SCALARS, ARRSEQS)
329+
_test_binary(op, seq_int, SCALARS, ARRSEQS)
330+
331+
# Test math operators with ArraySequence views.
332+
_test_binary(op, seq[::2], SCALARS, VIEWS)
333+
_test_binary(op, seq_int[::2], SCALARS, VIEWS)
334+
335+
if op in CMP_OPS:
336+
continue
337+
338+
op = "__i{}__".format(op.strip("_"))
339+
_test_binary(op, seq, SCALARS, ARRSEQS, inplace=True)
340+
341+
if op == "__itruediv__":
342+
continue # Going to deal with it separately.
343+
344+
_test_binary(op, seq_int, [42, -3, True, 0], [seq_int, seq_bool, -seq_int], inplace=True) # int <-- int
345+
assert_raises(TypeError, _test_binary, op, seq_int, [0.5], [], inplace=True) # int <-- float
346+
assert_raises(TypeError, _test_binary, op, seq_int, [], [seq], inplace=True) # int <-- float
347+
348+
# __pow__ : Integers to negative integer powers are not allowed.
349+
_test_binary("__pow__", seq, [42, -3, True, 0], [seq_int, seq_bool, -seq_int])
350+
_test_binary("__ipow__", seq, [42, -3, True, 0], [seq_int, seq_bool, -seq_int], inplace=True)
351+
assert_raises(ValueError, _test_binary, "__pow__", seq_int, [-3], [])
352+
assert_raises(ValueError, _test_binary, "__ipow__", seq_int, [-3], [], inplace=True)
353+
354+
# __itruediv__ is only valid with float arrseq.
355+
for scalar in SCALARS + ARRSEQS:
356+
assert_raises(TypeError, getattr(seq_int.copy(), "__itruediv__"), scalar)
357+
358+
# Bitwise operators
359+
for op in ("__lshift__", "__rshift__", "__or__", "__and__", "__xor__"):
360+
_test_binary(op, seq_bool, [42, -3, True, 0], [seq_int, seq_bool, -seq_int])
361+
assert_raises(TypeError, _test_binary, op, seq_bool, [0.5], [])
362+
assert_raises(TypeError, _test_binary, op, seq, [], [seq])
363+
335364
# Unary operators
336-
for op in ["__neg__"]:
337-
seq = getattr(SEQ_DATA['seq'], op)()
338-
assert_true(seq is not SEQ_DATA['seq'])
339-
check_arr_seq(seq, [getattr(d, op)() for d in SEQ_DATA['data']])
365+
for op in ["__neg__", "__abs__"]:
366+
_test_unary(op, seq)
367+
_test_unary(op, -seq)
368+
_test_unary(op, seq_int)
369+
_test_unary(op, -seq_int)
370+
371+
_test_unary("__abs__", seq_bool)
372+
_test_unary("__invert__", seq_bool)
373+
assert_raises(TypeError, _test_unary, "__invert__", seq)
374+
375+
# Restore flags.
376+
np.seterr(**flags)
377+
340378

341379
def test_arraysequence_setitem(self):
342380
# Set one item
@@ -346,13 +384,6 @@ def test_arraysequence_setitem(self):
346384

347385
check_arr_seq(seq, SEQ_DATA['seq'])
348386

349-
if sys.version_info < (3,):
350-
seq = ArraySequence(SEQ_DATA['seq'] * 0)
351-
for i, e in enumerate(SEQ_DATA['seq']):
352-
seq[long(i)] = e
353-
354-
check_arr_seq(seq, SEQ_DATA['seq'])
355-
356387
# Get all items using indexing (creates a view).
357388
indices = list(range(len(SEQ_DATA['seq'])))
358389
seq_view = SEQ_DATA['seq'][indices]

nibabel/streamlines/tests/test_tractogram.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -539,9 +539,6 @@ def test_tractogram_getitem(self):
539539
for i, t in enumerate(DATA['tractogram']):
540540
assert_tractogram_item_equal(DATA['tractogram'][i], t)
541541

542-
if sys.version_info < (3,):
543-
assert_tractogram_item_equal(DATA['tractogram'][long(i)], t)
544-
545542
# Get one TractogramItem out of two.
546543
tractogram_view = DATA['simple_tractogram'][::2]
547544
check_tractogram(tractogram_view, DATA['streamlines'][::2])

0 commit comments

Comments
 (0)