Commit 09f4ed6

RF: Add function to easily make fake streamline for testing purposes.
1 parent eaadbb5 commit 09f4ed6

File tree

1 file changed: +121 −43 lines changed

nibabel/streamlines/tests/test_tractogram.py

Lines changed: 121 additions & 43 deletions
@@ -3,6 +3,7 @@
 import numpy as np
 import warnings
 import operator
+from collections import defaultdict
 
 from nibabel.testing import assert_arrays_equal
 from nibabel.testing import clear_and_catch_warnings
@@ -17,37 +18,99 @@
 DATA = {}
 
 
+def make_fake_streamline(nb_points, data_per_point_shapes={},
+                         data_for_streamline_shapes={}, rng=None):
+    """ Make a single streamline according to provided requirements. """
+    if rng is None:
+        rng = np.random.RandomState()
+
+    streamline = rng.randn(nb_points, 3).astype("f4")
+
+    data_per_point = {}
+    for k, shape in data_per_point_shapes.items():
+        data_per_point[k] = rng.randn(*((nb_points,) + shape)).astype("f4")
+
+    data_for_streamline = {}
+    for k, shape in data_for_streamline_shapes.items():
+        data_for_streamline[k] = rng.randn(*shape).astype("f4")
+
+    return streamline, data_per_point, data_for_streamline
+
+
+def make_fake_tractogram(list_nb_points, data_per_point_shapes={},
+                         data_for_streamline_shapes={}, rng=None):
+    """ Make multiple streamlines according to provided requirements. """
+    all_streamlines = []
+    all_data_per_point = defaultdict(lambda: [])
+    all_data_per_streamline = defaultdict(lambda: [])
+    for nb_points in list_nb_points:
+        data = make_fake_streamline(nb_points, data_per_point_shapes,
+                                    data_for_streamline_shapes, rng)
+        streamline, data_per_point, data_for_streamline = data
+
+        all_streamlines.append(streamline)
+        for k, v in data_per_point.items():
+            all_data_per_point[k].append(v)
+
+        for k, v in data_for_streamline.items():
+            all_data_per_streamline[k].append(v)
+
+    return all_streamlines, all_data_per_point, all_data_per_streamline
+
+
+def make_dummy_streamline(nb_points):
+    """ Make the streamlines that have been used to create test data files."""
+    if nb_points == 1:
+        streamline = np.arange(1*3, dtype="f4").reshape((1, 3))
+        data_per_point = {"fa": np.array([[0.2]], dtype="f4"),
+                          "colors": np.array([(1, 0, 0)]*1, dtype="f4")}
+        data_for_streamline = {"mean_curvature": np.array([1.11], dtype="f4"),
+                               "mean_torsion": np.array([1.22], dtype="f4"),
+                               "mean_colors": np.array([1, 0, 0], dtype="f4")}
+
+    elif nb_points == 2:
+        streamline = np.arange(2*3, dtype="f4").reshape((2, 3))
+        data_per_point = {"fa": np.array([[0.3],
+                                          [0.4]], dtype="f4"),
+                          "colors": np.array([(0, 1, 0)]*2, dtype="f4")}
+        data_for_streamline = {"mean_curvature": np.array([2.11], dtype="f4"),
+                               "mean_torsion": np.array([2.22], dtype="f4"),
+                               "mean_colors": np.array([0, 1, 0], dtype="f4")}
+
+    elif nb_points == 5:
+        streamline = np.arange(5*3, dtype="f4").reshape((5, 3))
+        data_per_point = {"fa": np.array([[0.5],
+                                          [0.6],
+                                          [0.6],
+                                          [0.7],
+                                          [0.8]], dtype="f4"),
+                          "colors": np.array([(0, 0, 1)]*5, dtype="f4")}
+        data_for_streamline = {"mean_curvature": np.array([3.11], dtype="f4"),
+                               "mean_torsion": np.array([3.22], dtype="f4"),
+                               "mean_colors": np.array([0, 0, 1], dtype="f4")}
+
+    return streamline, data_per_point, data_for_streamline
+
+
 def setup():
     global DATA
     DATA['rng'] = np.random.RandomState(1234)
-    DATA['streamlines'] = [np.arange(1*3, dtype="f4").reshape((1, 3)),
-                           np.arange(2*3, dtype="f4").reshape((2, 3)),
-                           np.arange(5*3, dtype="f4").reshape((5, 3))]
-
-    DATA['fa'] = [np.array([[0.2]], dtype="f4"),
-                  np.array([[0.3],
-                            [0.4]], dtype="f4"),
-                  np.array([[0.5],
-                            [0.6],
-                            [0.6],
-                            [0.7],
-                            [0.8]], dtype="f4")]
-
-    DATA['colors'] = [np.array([(1, 0, 0)]*1, dtype="f4"),
-                      np.array([(0, 1, 0)]*2, dtype="f4"),
-                      np.array([(0, 0, 1)]*5, dtype="f4")]
-
-    DATA['mean_curvature'] = [np.array([1.11], dtype="f4"),
-                              np.array([2.11], dtype="f4"),
-                              np.array([3.11], dtype="f4")]
-
-    DATA['mean_torsion'] = [np.array([1.22], dtype="f4"),
-                            np.array([2.22], dtype="f4"),
-                            np.array([3.22], dtype="f4")]
-
-    DATA['mean_colors'] = [np.array([1, 0, 0], dtype="f4"),
-                           np.array([0, 1, 0], dtype="f4"),
-                           np.array([0, 0, 1], dtype="f4")]
+
+    DATA['streamlines'] = []
+    DATA['fa'] = []
+    DATA['colors'] = []
+    DATA['mean_curvature'] = []
+    DATA['mean_torsion'] = []
+    DATA['mean_colors'] = []
+    for nb_points in [1, 2, 5]:
+        data = make_dummy_streamline(nb_points)
+        streamline, data_per_point, data_for_streamline = data
+        DATA['streamlines'].append(streamline)
+        DATA['fa'].append(data_per_point['fa'])
+        DATA['colors'].append(data_per_point['colors'])
+        DATA['mean_curvature'].append(data_for_streamline['mean_curvature'])
+        DATA['mean_torsion'].append(data_for_streamline['mean_torsion'])
+        DATA['mean_colors'].append(data_for_streamline['mean_colors'])
 
     DATA['data_per_point'] = {'colors': DATA['colors'],
                               'fa': DATA['fa']}
@@ -280,9 +343,14 @@ def test_extend(self):
         total_nb_rows = DATA['tractogram'].streamlines.total_nb_rows
         sdict = PerArraySequenceDict(total_nb_rows, DATA['data_per_point'])
 
-        new_data = {'colors': 2 * np.array(DATA['colors']),
-                    'fa': 3 * np.array(DATA['fa'])}
-        sdict2 = PerArraySequenceDict(total_nb_rows, new_data)
+        # Test compatible PerArrayDicts.
+        list_nb_points = [2, 7, 4]
+        data_per_point_shapes = {"colors": DATA['colors'][0].shape[1:],
+                                 "fa": DATA['fa'][0].shape[1:]}
+        _, new_data, _ = make_fake_tractogram(list_nb_points,
+                                              data_per_point_shapes,
+                                              rng=DATA['rng'])
+        sdict2 = PerArraySequenceDict(np.sum(list_nb_points), new_data)
 
         sdict.extend(sdict2)
         assert_equal(len(sdict), len(sdict2))
@@ -297,16 +365,22 @@ def test_extend(self):
         assert_raises(ValueError, sdict.extend, PerArraySequenceDict())
 
         # Other dict has more entries.
-        new_data = {'colors': 2 * np.array(DATA['colors']),
-                    'fa': 3 * np.array(DATA['fa']),
-                    'other': 4 * np.array(DATA['fa'])}
-        sdict2 = PerArraySequenceDict(total_nb_rows, new_data)
+        data_per_point_shapes = {"colors": DATA['colors'][0].shape[1:],
+                                 "fa": DATA['fa'][0].shape[1:],
+                                 "other": (7,)}
+        _, new_data, _ = make_fake_tractogram(list_nb_points,
+                                              data_per_point_shapes,
+                                              rng=DATA['rng'])
+        sdict2 = PerArraySequenceDict(np.sum(list_nb_points), new_data)
         assert_raises(ValueError, sdict.extend, sdict2)
 
         # Other dict has the right number of entries but wrong shape.
-        new_data = {'colors': 2 * np.array(DATA['colors']),
-                    'other': 2 * np.array(DATA['colors']),}
-        sdict2 = PerArraySequenceDict(total_nb_rows, new_data)
+        data_per_point_shapes = {"colors": DATA['colors'][0].shape[1:],
+                                 "fa": DATA['fa'][0].shape[1:] + (3,)}
+        _, new_data, _ = make_fake_tractogram(list_nb_points,
+                                              data_per_point_shapes,
+                                              rng=DATA['rng'])
+        sdict2 = PerArraySequenceDict(np.sum(list_nb_points), new_data)
         assert_raises(ValueError, sdict.extend, sdict2)
 
 
@@ -650,13 +724,15 @@ def test_tractogram_extend(self):
         # Load tractogram that contains some metadata.
         t = DATA['tractogram'].copy()
 
-        for op, in_place in ((operator.add, False), (operator.iadd, True), (extender, True)):
+        for op, in_place in ((operator.add, False), (operator.iadd, True),
+                             (extender, True)):
             first_arg = t.copy()
             new_t = op(first_arg, t)
             assert_equal(new_t is first_arg, in_place)
             assert_tractogram_equal(new_t[:len(t)], DATA['tractogram'])
             assert_tractogram_equal(new_t[len(t):], DATA['tractogram'])
 
+
 class TestLazyTractogram(unittest.TestCase):
 
     def test_lazy_tractogram_creation(self):
@@ -670,7 +746,8 @@ def test_lazy_tractogram_creation(self):
                                'mean_colors': (x for x in DATA['mean_colors'])}
 
         # Creating LazyTractogram with generators is not allowed as
-        # generators get exhausted and are not reusable unlike generator function.
+        # generators get exhausted and are not reusable unlike generator
+        # function.
         assert_raises(TypeError, LazyTractogram, streamlines)
         assert_raises(TypeError, LazyTractogram,
                       data_per_streamline=data_per_streamline)
@@ -701,7 +778,8 @@ def test_lazy_tractogram_from_data_func(self):
         tractogram = LazyTractogram.from_data_func(_empty_data_gen)
         check_tractogram(tractogram)
 
-        # Create `LazyTractogram` from a generator function yielding TractogramItem.
+        # Create `LazyTractogram` from a generator function yielding
+        # TractogramItem.
         data = [DATA['streamlines'], DATA['fa'], DATA['colors'],
                 DATA['mean_curvature'], DATA['mean_torsion'],
                 DATA['mean_colors']]
@@ -839,8 +917,8 @@ def test_lazy_tractogram_copy(self):
         # Check we copied the data and not simply created new references.
         assert_true(tractogram is not DATA['lazy_tractogram'])
 
-        # When copying LazyTractogram, the generator function yielding streamlines
-        # should stay the same.
+        # When copying LazyTractogram, the generator function yielding
+        # streamlines should stay the same.
         assert_true(tractogram._streamlines
                     is DATA['lazy_tractogram']._streamlines)
 
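As a quick sanity check of the new helpers, here is a minimal usage sketch (assuming the test module is importable from an installed nibabel; the key names "fa", "mean_color", "colors", "mean_torsion" and the shapes below are illustrative choices, not values used by the tests):

import numpy as np
from nibabel.streamlines.tests.test_tractogram import (make_fake_streamline,
                                                        make_fake_tractogram)

# Seeded generator so the fake data is reproducible.
rng = np.random.RandomState(42)

# A single 10-point streamline with one per-point scalar and one
# per-streamline vector (key names and shapes are hypothetical).
streamline, per_point, per_streamline = make_fake_streamline(
    nb_points=10,
    data_per_point_shapes={"fa": (1,)},
    data_for_streamline_shapes={"mean_color": (3,)},
    rng=rng)
assert streamline.shape == (10, 3)
assert per_point["fa"].shape == (10, 1)
assert per_streamline["mean_color"].shape == (3,)

# A small fake tractogram: three streamlines of different lengths sharing
# the same data keys; each dict value is a list, one entry per streamline.
streamlines, per_point_lists, per_streamline_lists = make_fake_tractogram(
    list_nb_points=[2, 7, 4],
    data_per_point_shapes={"colors": (3,)},
    data_for_streamline_shapes={"mean_torsion": (1,)},
    rng=rng)
assert len(streamlines) == 3
assert per_point_lists["colors"][1].shape == (7, 3)
assert per_streamline_lists["mean_torsion"][2].shape == (1,)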

0 commit comments