@@ -40,7 +40,8 @@ class SimulateMultiTensorInputSpec(BaseInterfaceInputSpec):
40
40
in_frac = InputMultiPath (File (exists = True ), mandatory = True ,
41
41
desc = ('volume fraction of each fiber' ))
42
42
in_vfms = InputMultiPath (File (exists = True ), mandatory = True ,
43
- desc = 'volume fraction map' )
43
+ desc = ('volume fractions of isotropic '
44
+ 'compartiments' ))
44
45
in_mask = File (exists = True , desc = 'mask to simulate data' )
45
46
46
47
n_proc = traits .Int (0 , usedefault = True , desc = 'number of processes' )
@@ -60,7 +61,7 @@ class SimulateMultiTensorInputSpec(BaseInterfaceInputSpec):
60
61
desc = 'file with the mask simulated' )
61
62
out_bvec = File ('bvec.sim' , usedefault = True , desc = 'simulated b vectors' )
62
63
out_bval = File ('bval.sim' , usedefault = True , desc = 'simulated b values' )
63
- snr = traits .Int (30 , usedefault = True , desc = 'signal-to-noise ratio (dB)' )
64
+ snr = traits .Int (0 , usedefault = True , desc = 'signal-to-noise ratio (dB)' )
64
65
65
66
66
67
class SimulateMultiTensorOutputSpec (TraitedSpec ):
@@ -100,38 +101,47 @@ def _run_interface(self, runtime):
100
101
shape = b0_im .get_shape ()
101
102
aff = b0_im .get_affine ()
102
103
104
+ # Check and load sticks and their volume fractions
105
+ nsticks = len (self .inputs .in_dirs )
106
+ if len (self .inputs .in_frac ) != nsticks :
107
+ raise RuntimeError (('Number of sticks and their volume fractions'
108
+ ' must match.' ))
109
+
103
110
ffsim = nb .concat_images ([nb .load (f ) for f in self .inputs .in_frac ])
104
111
ffs = np .squeeze (ffsim .get_data ()) # fiber fractions
112
+ if nsticks == 1 :
113
+ ffs = ffs [..., np .newaxis ]
105
114
115
+ # Volume fractions of isotropic compartiments
106
116
vfsim = nb .concat_images ([nb .load (f ) for f in self .inputs .in_vfms ])
107
117
vfs = np .squeeze (vfsim .get_data ()) # volume fractions
108
118
109
119
total_ff = np .sum (ffs , axis = 3 )
110
120
total_vf = np .sum (vfs , axis = 3 )
111
121
112
- msk = np .zeros (shape , dtype = np .uint8 )
113
- msk [(total_vf > 0.0 )] = 1
114
-
122
+ # Generate a mask
115
123
if isdefined (self .inputs .in_mask ):
116
124
msk = nb .load (self .inputs .in_mask ).get_data ()
117
125
msk [msk > 0.0 ] = 1.0
118
126
msk [msk < 1.0 ] = 0.0
127
+ else :
128
+ msk = np .zeros (shape , dtype = np .uint8 )
129
+ msk [total_vf > 0.0 ] = 1
119
130
120
131
mhdr = hdr .copy ()
121
132
mhdr .set_data_dtype (np .uint8 )
122
133
mhdr .set_xyzt_units ('mm' , 'sec' )
123
134
nb .Nifti1Image (msk , aff , mhdr ).to_filename (
124
135
op .abspath (self .inputs .out_mask ))
125
136
137
+ # Initialize stack of args
126
138
args = np .hstack ((vfs [msk > 0 ], ffs [msk > 0 ]))
127
139
140
+ # Stack directions
128
141
for f in self .inputs .in_dirs :
129
142
fd = nb .load (f ).get_data ()
130
143
args = np .hstack ((args , fd [msk > 0 ]))
131
144
132
- b0 = np .array ([b0_im .get_data ()[msk > 0 ]]).T
133
- args = np .hstack ((args , b0 ))
134
-
135
145
if isdefined (self .inputs .in_bval ) and isdefined (self .inputs .in_bvec ):
136
146
# Load the gradient strengths and directions
137
147
bvals = np .loadtxt (self .inputs .in_bval )
@@ -147,7 +157,7 @@ def _run_interface(self, runtime):
147
157
np .savetxt (op .abspath (self .inputs .out_bval ), gtab .bvals )
148
158
149
159
snr = self .inputs .snr
150
- args = [tuple (np . hstack (( r , gtab , snr ) )) for r in args ]
160
+ args = [tuple ([ nsticks , gtab , snr ] + r . tolist ( )) for r in args ]
151
161
152
162
n_proc = self .inputs .n_proc
153
163
if n_proc == 0 :
@@ -160,12 +170,22 @@ def _run_interface(self, runtime):
160
170
161
171
iflogger .info (('Starting simulation of %d voxels, %d diffusion'
162
172
' directions.' ) % (len (args ), len (gtab .bvals )))
163
- result = pool .map (_compute_voxel , args )
164
- ndirs = np .shape (result )[1 ]
173
+
174
+ result = np .array (pool .map (_compute_voxel , args ))
175
+
176
+ ndirs = len (gtab .bvals )
177
+ if np .shape (result )[1 ] != ndirs :
178
+ raise RuntimeError (('Computed directions do not match number'
179
+ 'of b-values.' ))
165
180
166
181
simulated = np .zeros ((shape [0 ], shape [1 ], shape [2 ], ndirs ))
167
182
simulated [msk > 0 ] = result
168
183
184
+ # S0
185
+ b0 = b0_im .get_data ()
186
+ for i in xrange (ndirs ):
187
+ simulated [..., i ] *= b0
188
+
169
189
simhdr = hdr .copy ()
170
190
simhdr .set_data_dtype (np .float32 )
171
191
simhdr .set_xyzt_units ('mm' , 'sec' )
@@ -202,42 +222,41 @@ def _compute_voxel(args):
202
222
D_ball = [3000e-6 , 960e-6 , 680e-6 ]
203
223
sf_evals = [1700e-6 , 200e-6 , 200e-6 ]
204
224
205
- vfs = [args [0 ], args [1 ], args [2 ]]
206
- ffs = [args [3 ], args [4 ], args [5 ]] # single fiber fractions
207
- sticks = [(args [6 ], args [7 ], args [8 ]),
208
- (args [8 ], args [10 ], args [11 ]),
209
- (args [12 ], args [13 ], args [14 ])]
225
+ nf = args [0 ] # number of fibers
226
+ gtab = args [1 ] # gradient table
227
+ snr = args [2 ]
228
+ vfs = args [3 :6 ]
229
+
230
+ vfs = (np .array (vfs ) / np .sum (vfs ))
231
+
232
+ sst = 6 + nf
233
+ ffs = args [6 :sst ] # single fiber fractions
210
234
211
- S0 = args [15 ]
212
- gtab = args [ 16 ]
235
+ sticks = [ tuple ( args [sst + i * 3 : sst + 3 + i * 3 ])
236
+ for i in range ( 0 , nf ) ]
213
237
214
- nf = len (ffs )
215
238
mevals = [sf_evals ] * nf
216
239
sf_vf = np .sum (ffs )
217
- ffs = ((np .array (ffs ) / sf_vf ) * 100 )
218
240
219
241
# Simulate sticks
220
- signal , _ = multi_tensor (gtab , np .array (mevals ), S0 = 1 ,
221
- angles = sticks , fractions = ffs , snr = None )
222
- signal *= sf_vf
242
+ if sf_vf > 1.0e-3 :
243
+ ffs = ((np .array (ffs ) / sf_vf ) * 100 )
244
+ signal , _ = multi_tensor (gtab , np .array (mevals ), S0 = 1.0 ,
245
+ angles = sticks , fractions = ffs , snr = None )
246
+ else :
247
+ signal = np .zeros_like (gtab .bvals , dtype = np .float32 )
248
+
249
+ signal *= vfs [2 ] * sf_vf
223
250
224
251
# Simulate balls
225
- r = 1.0 - sf_vf
226
- if r > 1.0e-3 :
227
- for vf , d in zip (vfs , D_ball ):
228
- f0 = vf * r
229
- signal += f0 * np .exp (- gtab .bvals * d )
230
-
231
- snr = None
232
- try :
233
- snr = args [17 ]
234
- except IndexError :
235
- pass
236
-
237
- if snr is not None and snr >= 0 :
238
- signal [1 :] = add_noise (signal [1 :], snr , 1 )
239
-
240
- return signal * S0
252
+ vfs [2 ] *= (1 - sf_vf )
253
+ for f0 , d in zip (vfs , D_ball ):
254
+ signal += f0 * np .exp (- gtab .bvals * d )
255
+
256
+ if snr > 0 :
257
+ signal = add_noise (signal , snr , 1 )
258
+
259
+ return signal .tolist ()
241
260
242
261
243
262
def _generate_gradients (ndirs = 64 , values = [1000 , 3000 ], nb0s = 1 ):
0 commit comments