3
3
import numpy as np
4
4
import warnings
5
5
import operator
6
+ from collections import defaultdict
6
7
7
8
from nibabel .testing import assert_arrays_equal
8
9
from nibabel .testing import clear_and_catch_warnings
17
18
DATA = {}
18
19
19
20
21
+ def make_fake_streamline (nb_points , data_per_point_shapes = {},
22
+ data_for_streamline_shapes = {}, rng = None ):
23
+ """ Make a single streamline according to provided requirements. """
24
+ if rng is None :
25
+ rng = np .random .RandomState ()
26
+
27
+ streamline = rng .randn (nb_points , 3 ).astype ("f4" )
28
+
29
+ data_per_point = {}
30
+ for k , shape in data_per_point_shapes .items ():
31
+ data_per_point [k ] = rng .randn (* ((nb_points ,) + shape )).astype ("f4" )
32
+
33
+ data_for_streamline = {}
34
+ for k , shape in data_for_streamline .items ():
35
+ data_for_streamline [k ] = rng .randn (* shape ).astype ("f4" )
36
+
37
+ return streamline , data_per_point , data_for_streamline
38
+
39
+
40
+ def make_fake_tractogram (list_nb_points , data_per_point_shapes = {},
41
+ data_for_streamline_shapes = {}, rng = None ):
42
+ """ Make multiple streamlines according to provided requirements. """
43
+ all_streamlines = []
44
+ all_data_per_point = defaultdict (lambda : [])
45
+ all_data_per_streamline = defaultdict (lambda : [])
46
+ for nb_points in list_nb_points :
47
+ data = make_fake_streamline (nb_points , data_per_point_shapes ,
48
+ data_for_streamline_shapes , rng )
49
+ streamline , data_per_point , data_for_streamline = data
50
+
51
+ all_streamlines .append (streamline )
52
+ for k , v in data_per_point .items ():
53
+ all_data_per_point [k ].append (v )
54
+
55
+ for k , v in data_for_streamline .items ():
56
+ all_data_per_streamline [k ].append (v )
57
+
58
+ return all_streamlines , all_data_per_point , all_data_per_streamline
59
+
60
+
61
+ def make_dummy_streamline (nb_points ):
62
+ """ Make the streamlines that have been used to create test data files."""
63
+ if nb_points == 1 :
64
+ streamline = np .arange (1 * 3 , dtype = "f4" ).reshape ((1 , 3 ))
65
+ data_per_point = {"fa" : np .array ([[0.2 ]], dtype = "f4" ),
66
+ "colors" : np .array ([(1 , 0 , 0 )]* 1 , dtype = "f4" )}
67
+ data_for_streamline = {"mean_curvature" : np .array ([1.11 ], dtype = "f4" ),
68
+ "mean_torsion" : np .array ([1.22 ], dtype = "f4" ),
69
+ "mean_colors" : np .array ([1 , 0 , 0 ], dtype = "f4" )}
70
+
71
+ elif nb_points == 2 :
72
+ streamline = np .arange (2 * 3 , dtype = "f4" ).reshape ((2 , 3 ))
73
+ data_per_point = {"fa" : np .array ([[0.3 ],
74
+ [0.4 ]], dtype = "f4" ),
75
+ "colors" : np .array ([(0 , 1 , 0 )]* 2 , dtype = "f4" )}
76
+ data_for_streamline = {"mean_curvature" : np .array ([2.11 ], dtype = "f4" ),
77
+ "mean_torsion" : np .array ([2.22 ], dtype = "f4" ),
78
+ "mean_colors" : np .array ([0 , 1 , 0 ], dtype = "f4" )}
79
+
80
+ elif nb_points == 5 :
81
+ streamline = np .arange (5 * 3 , dtype = "f4" ).reshape ((5 , 3 ))
82
+ data_per_point = {"fa" : np .array ([[0.5 ],
83
+ [0.6 ],
84
+ [0.6 ],
85
+ [0.7 ],
86
+ [0.8 ]], dtype = "f4" ),
87
+ "colors" : np .array ([(0 , 0 , 1 )]* 5 , dtype = "f4" )}
88
+ data_for_streamline = {"mean_curvature" : np .array ([3.11 ], dtype = "f4" ),
89
+ "mean_torsion" : np .array ([3.22 ], dtype = "f4" ),
90
+ "mean_colors" : np .array ([0 , 0 , 1 ], dtype = "f4" )}
91
+
92
+ return streamline , data_per_point , data_for_streamline
93
+
94
+
20
95
def setup ():
21
96
global DATA
22
97
DATA ['rng' ] = np .random .RandomState (1234 )
23
- DATA ['streamlines' ] = [np .arange (1 * 3 , dtype = "f4" ).reshape ((1 , 3 )),
24
- np .arange (2 * 3 , dtype = "f4" ).reshape ((2 , 3 )),
25
- np .arange (5 * 3 , dtype = "f4" ).reshape ((5 , 3 ))]
26
-
27
- DATA ['fa' ] = [np .array ([[0.2 ]], dtype = "f4" ),
28
- np .array ([[0.3 ],
29
- [0.4 ]], dtype = "f4" ),
30
- np .array ([[0.5 ],
31
- [0.6 ],
32
- [0.6 ],
33
- [0.7 ],
34
- [0.8 ]], dtype = "f4" )]
35
-
36
- DATA ['colors' ] = [np .array ([(1 , 0 , 0 )]* 1 , dtype = "f4" ),
37
- np .array ([(0 , 1 , 0 )]* 2 , dtype = "f4" ),
38
- np .array ([(0 , 0 , 1 )]* 5 , dtype = "f4" )]
39
-
40
- DATA ['mean_curvature' ] = [np .array ([1.11 ], dtype = "f4" ),
41
- np .array ([2.11 ], dtype = "f4" ),
42
- np .array ([3.11 ], dtype = "f4" )]
43
-
44
- DATA ['mean_torsion' ] = [np .array ([1.22 ], dtype = "f4" ),
45
- np .array ([2.22 ], dtype = "f4" ),
46
- np .array ([3.22 ], dtype = "f4" )]
47
-
48
- DATA ['mean_colors' ] = [np .array ([1 , 0 , 0 ], dtype = "f4" ),
49
- np .array ([0 , 1 , 0 ], dtype = "f4" ),
50
- np .array ([0 , 0 , 1 ], dtype = "f4" )]
98
+
99
+ DATA ['streamlines' ] = []
100
+ DATA ['fa' ] = []
101
+ DATA ['colors' ] = []
102
+ DATA ['mean_curvature' ] = []
103
+ DATA ['mean_torsion' ] = []
104
+ DATA ['mean_colors' ] = []
105
+ for nb_points in [1 , 2 , 5 ]:
106
+ data = make_dummy_streamline (nb_points )
107
+ streamline , data_per_point , data_for_streamline = data
108
+ DATA ['streamlines' ].append (streamline )
109
+ DATA ['fa' ].append (data_per_point ['fa' ])
110
+ DATA ['colors' ].append (data_per_point ['colors' ])
111
+ DATA ['mean_curvature' ].append (data_for_streamline ['mean_curvature' ])
112
+ DATA ['mean_torsion' ].append (data_for_streamline ['mean_torsion' ])
113
+ DATA ['mean_colors' ].append (data_for_streamline ['mean_colors' ])
51
114
52
115
DATA ['data_per_point' ] = {'colors' : DATA ['colors' ],
53
116
'fa' : DATA ['fa' ]}
@@ -280,9 +343,14 @@ def test_extend(self):
280
343
total_nb_rows = DATA ['tractogram' ].streamlines .total_nb_rows
281
344
sdict = PerArraySequenceDict (total_nb_rows , DATA ['data_per_point' ])
282
345
283
- new_data = {'colors' : 2 * np .array (DATA ['colors' ]),
284
- 'fa' : 3 * np .array (DATA ['fa' ])}
285
- sdict2 = PerArraySequenceDict (total_nb_rows , new_data )
346
+ # Test compatible PerArrayDicts.
347
+ list_nb_points = [2 , 7 , 4 ]
348
+ data_per_point_shapes = {"colors" : DATA ['colors' ][0 ].shape [1 :],
349
+ "fa" : DATA ['fa' ][0 ].shape [1 :]}
350
+ _ , new_data , _ = make_fake_tractogram (list_nb_points ,
351
+ data_per_point_shapes ,
352
+ rng = DATA ['rng' ])
353
+ sdict2 = PerArraySequenceDict (np .sum (list_nb_points ), new_data )
286
354
287
355
sdict .extend (sdict2 )
288
356
assert_equal (len (sdict ), len (sdict2 ))
@@ -297,16 +365,22 @@ def test_extend(self):
297
365
assert_raises (ValueError , sdict .extend , PerArraySequenceDict ())
298
366
299
367
# Other dict has more entries.
300
- new_data = {'colors' : 2 * np .array (DATA ['colors' ]),
301
- 'fa' : 3 * np .array (DATA ['fa' ]),
302
- 'other' : 4 * np .array (DATA ['fa' ])}
303
- sdict2 = PerArraySequenceDict (total_nb_rows , new_data )
368
+ data_per_point_shapes = {"colors" : DATA ['colors' ][0 ].shape [1 :],
369
+ "fa" : DATA ['fa' ][0 ].shape [1 :],
370
+ "other" : (7 ,)}
371
+ _ , new_data , _ = make_fake_tractogram (list_nb_points ,
372
+ data_per_point_shapes ,
373
+ rng = DATA ['rng' ])
374
+ sdict2 = PerArraySequenceDict (np .sum (list_nb_points ), new_data )
304
375
assert_raises (ValueError , sdict .extend , sdict2 )
305
376
306
377
# Other dict has the right number of entries but wrong shape.
307
- new_data = {'colors' : 2 * np .array (DATA ['colors' ]),
308
- 'other' : 2 * np .array (DATA ['colors' ]),}
309
- sdict2 = PerArraySequenceDict (total_nb_rows , new_data )
378
+ data_per_point_shapes = {"colors" : DATA ['colors' ][0 ].shape [1 :],
379
+ "fa" : DATA ['fa' ][0 ].shape [1 :] + (3 ,)}
380
+ _ , new_data , _ = make_fake_tractogram (list_nb_points ,
381
+ data_per_point_shapes ,
382
+ rng = DATA ['rng' ])
383
+ sdict2 = PerArraySequenceDict (np .sum (list_nb_points ), new_data )
310
384
assert_raises (ValueError , sdict .extend , sdict2 )
311
385
312
386
@@ -650,13 +724,15 @@ def test_tractogram_extend(self):
650
724
# Load tractogram that contains some metadata.
651
725
t = DATA ['tractogram' ].copy ()
652
726
653
- for op , in_place in ((operator .add , False ), (operator .iadd , True ), (extender , True )):
727
+ for op , in_place in ((operator .add , False ), (operator .iadd , True ),
728
+ (extender , True )):
654
729
first_arg = t .copy ()
655
730
new_t = op (first_arg , t )
656
731
assert_equal (new_t is first_arg , in_place )
657
732
assert_tractogram_equal (new_t [:len (t )], DATA ['tractogram' ])
658
733
assert_tractogram_equal (new_t [len (t ):], DATA ['tractogram' ])
659
734
735
+
660
736
class TestLazyTractogram (unittest .TestCase ):
661
737
662
738
def test_lazy_tractogram_creation (self ):
@@ -670,7 +746,8 @@ def test_lazy_tractogram_creation(self):
670
746
'mean_colors' : (x for x in DATA ['mean_colors' ])}
671
747
672
748
# Creating LazyTractogram with generators is not allowed as
673
- # generators get exhausted and are not reusable unlike generator function.
749
+ # generators get exhausted and are not reusable unlike generator
750
+ # function.
674
751
assert_raises (TypeError , LazyTractogram , streamlines )
675
752
assert_raises (TypeError , LazyTractogram ,
676
753
data_per_streamline = data_per_streamline )
@@ -701,7 +778,8 @@ def test_lazy_tractogram_from_data_func(self):
701
778
tractogram = LazyTractogram .from_data_func (_empty_data_gen )
702
779
check_tractogram (tractogram )
703
780
704
- # Create `LazyTractogram` from a generator function yielding TractogramItem.
781
+ # Create `LazyTractogram` from a generator function yielding
782
+ # TractogramItem.
705
783
data = [DATA ['streamlines' ], DATA ['fa' ], DATA ['colors' ],
706
784
DATA ['mean_curvature' ], DATA ['mean_torsion' ],
707
785
DATA ['mean_colors' ]]
@@ -839,8 +917,8 @@ def test_lazy_tractogram_copy(self):
839
917
# Check we copied the data and not simply created new references.
840
918
assert_true (tractogram is not DATA ['lazy_tractogram' ])
841
919
842
- # When copying LazyTractogram, the generator function yielding streamlines
843
- # should stay the same.
920
+ # When copying LazyTractogram, the generator function yielding
921
+ # streamlines should stay the same.
844
922
assert_true (tractogram ._streamlines
845
923
is DATA ['lazy_tractogram' ]._streamlines )
846
924
0 commit comments