@@ -45,11 +45,11 @@ class SimulateMultiTensorInputSpec(BaseInterfaceInputSpec):
45
45
in_mask = File (exists = True , desc = 'mask to simulate data' )
46
46
47
47
diff_iso = traits .List (
48
- traits . Float , default = [3000e-6 , 960e-6 , 680e-6 ], usedefault = True ,
48
+ [3000e-6 , 960e-6 , 680e-6 ], traits . Float , usedefault = True ,
49
49
desc = 'Diffusivity of isotropic compartments' )
50
50
diff_sf = traits .Tuple (
51
- traits . Float , traits . Float , traits . Float ,
52
- default = ( 1700e-6 , 200e-6 , 200e-6 ) , usedefault = True ,
51
+ ( 1700e-6 , 200e-6 , 200e-6 ) ,
52
+ traits . Float , traits . Float , traits . Float , usedefault = True ,
53
53
desc = 'Single fiber tensor' )
54
54
55
55
n_proc = traits .Int (0 , usedefault = True , desc = 'number of processes' )
@@ -128,14 +128,35 @@ def _run_interface(self, runtime):
128
128
raise RuntimeError (('Number of sticks and their volume fractions'
129
129
' must match.' ))
130
130
131
- ffsim = nb .concat_images ([nb .load (f ) for f in self .inputs .in_frac ])
132
- ffs = np .squeeze (ffsim .get_data ()) # fiber fractions
133
- ffs [ffs > 1.0 ] = 1.0
134
- ffs [ffs < 0.0 ] = 0.0
131
+ # Volume fractions of isotropic compartments
132
+ nballs = len (self .inputs .in_vfms )
133
+ vfs = np .squeeze (nb .concat_images ([nb .load (f ) for f in self .inputs .in_vfms ]).get_data ())
134
+ if nballs == 1 :
135
+ vfs = vfs [..., np .newaxis ]
136
+ total_vf = np .sum (vfs , axis = 3 )
135
137
138
+ # Generate a mask
139
+ if isdefined (self .inputs .in_mask ):
140
+ msk = nb .load (self .inputs .in_mask ).get_data ()
141
+ msk [msk > 0.0 ] = 1.0
142
+ msk [msk < 1.0 ] = 0.0
143
+ else :
144
+ msk = np .zeros (shape )
145
+ msk [total_vf > 0.0 ] = 1.0
146
+
147
+ msk = np .clip (msk , 0.0 , 1.0 )
148
+ nvox = len (msk [msk > 0 ])
149
+
150
+ # Fiber fractions
151
+ ffsim = nb .concat_images ([nb .load (f ) for f in self .inputs .in_frac ])
152
+ ffs = np .nan_to_num (np .squeeze (ffsim .get_data ())) # fiber fractions
153
+ ffs = np .clip (ffs , 0. , 1. )
136
154
if nsticks == 1 :
137
155
ffs = ffs [..., np .newaxis ]
138
156
157
+ for i in range (nsticks ):
158
+ ffs [..., i ] *= msk
159
+
139
160
total_ff = np .sum (ffs , axis = 3 )
140
161
141
162
# Fix incongruencies in fiber fractions
@@ -147,33 +168,14 @@ def _run_interface(self, runtime):
147
168
ffs [ffs < 0.0 ] = 0.0
148
169
total_ff = np .sum (ffs , axis = 3 )
149
170
150
- # Volume fractions of isotropic compartiments
151
- nballs = len (self .inputs .in_vfms )
152
- vfs = np .squeeze (nb .concat_images ([nb .load (f ) for f in self .inputs .in_vfms ]).get_data ())
153
- if nsticks == 1 :
154
- vfs = vfs [..., np .newaxis ]
155
-
156
-
157
171
for i in range (vfs .shape [- 1 ]):
158
172
vfs [..., i ] -= total_ff
159
- vfs [ vfs < 0.0 ] = 0
173
+ vfs = np . clip ( vfs , 0. , 1. )
160
174
161
175
fractions = np .concatenate ((ffs , vfs ), axis = 3 )
162
- total_vf = np .sum (fractions , axis = 3 )
163
176
nb .Nifti1Image (fractions , aff , None ).to_filename ('fractions.nii.gz' )
164
177
nb .Nifti1Image (total_vf , aff , None ).to_filename ('total_vf.nii.gz' )
165
178
166
- # Generate a mask
167
- if isdefined (self .inputs .in_mask ):
168
- msk = nb .load (self .inputs .in_mask ).get_data ()
169
- msk [msk > 0.0 ] = 1.0
170
- msk [msk < 1.0 ] = 0.0
171
- else :
172
- msk = np .zeros (shape , dtype = np .uint8 )
173
- msk [total_vf > 0.0 ] = 1
174
-
175
- nvox = len (mask [mask > 0 ])
176
-
177
179
mhdr = hdr .copy ()
178
180
mhdr .set_data_dtype (np .uint8 )
179
181
mhdr .set_xyzt_units ('mm' , 'sec' )
@@ -194,19 +196,18 @@ def _run_interface(self, runtime):
194
196
195
197
196
198
sf_evals = list (self .inputs .diff_sf )
197
- ba_evals = self .inputs .diff_iso
199
+ ba_evals = list ( self .inputs .diff_iso )
198
200
201
+ mevals = [sf_evals ] * nsticks + [[ba_evals [d ]]* 3 for d in range (nballs )]
199
202
args = []
200
203
for i in range (nvox ):
201
204
args .append (
202
205
{'fractions' : fracs [i , ...].tolist (),
203
- 'sticks' : [( 1.0 , 0.0 , 0.0 )] * nballs + dirs [ i , ...]. tolist () ,
206
+ 'sticks' : [tuple ( dirs [ i , j : j + 3 ]) for j in range ( nsticks )] + [( 1.0 , 0.0 , 0.0 )] * nballs ,
204
207
'gradients' : gtab ,
205
- 'mevals' : [[ ba_evals [ d ] * 3 ] for d in range ( nballs )] + [ sf_evals ] * nsticks
208
+ 'mevals' : mevals
206
209
})
207
210
208
- print args [:5 ]
209
-
210
211
n_proc = self .inputs .n_proc
211
212
if n_proc == 0 :
212
213
n_proc = cpu_count ()
0 commit comments