Skip to content

Commit 0ef93e8

Browse files
committed
allow for N number of fibers per voxel
1 parent 11ada63 commit 0ef93e8

File tree

1 file changed

+58
-39
lines changed

1 file changed

+58
-39
lines changed

nipype/interfaces/dipy/simulate.py

Lines changed: 58 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,8 @@ class SimulateMultiTensorInputSpec(BaseInterfaceInputSpec):
4040
in_frac = InputMultiPath(File(exists=True), mandatory=True,
4141
desc=('volume fraction of each fiber'))
4242
in_vfms = InputMultiPath(File(exists=True), mandatory=True,
43-
desc='volume fraction map')
43+
desc=('volume fractions of isotropic '
44+
'compartiments'))
4445
in_mask = File(exists=True, desc='mask to simulate data')
4546

4647
n_proc = traits.Int(0, usedefault=True, desc='number of processes')
@@ -60,7 +61,7 @@ class SimulateMultiTensorInputSpec(BaseInterfaceInputSpec):
6061
desc='file with the mask simulated')
6162
out_bvec = File('bvec.sim', usedefault=True, desc='simulated b vectors')
6263
out_bval = File('bval.sim', usedefault=True, desc='simulated b values')
63-
snr = traits.Int(30, usedefault=True, desc='signal-to-noise ratio (dB)')
64+
snr = traits.Int(0, usedefault=True, desc='signal-to-noise ratio (dB)')
6465

6566

6667
class SimulateMultiTensorOutputSpec(TraitedSpec):
@@ -100,38 +101,47 @@ def _run_interface(self, runtime):
100101
shape = b0_im.get_shape()
101102
aff = b0_im.get_affine()
102103

104+
# Check and load sticks and their volume fractions
105+
nsticks = len(self.inputs.in_dirs)
106+
if len(self.inputs.in_frac) != nsticks:
107+
raise RuntimeError(('Number of sticks and their volume fractions'
108+
' must match.'))
109+
103110
ffsim = nb.concat_images([nb.load(f) for f in self.inputs.in_frac])
104111
ffs = np.squeeze(ffsim.get_data()) # fiber fractions
112+
if nsticks == 1:
113+
ffs = ffs[..., np.newaxis]
105114

115+
# Volume fractions of isotropic compartiments
106116
vfsim = nb.concat_images([nb.load(f) for f in self.inputs.in_vfms])
107117
vfs = np.squeeze(vfsim.get_data()) # volume fractions
108118

109119
total_ff = np.sum(ffs, axis=3)
110120
total_vf = np.sum(vfs, axis=3)
111121

112-
msk = np.zeros(shape, dtype=np.uint8)
113-
msk[(total_vf > 0.0)] = 1
114-
122+
# Generate a mask
115123
if isdefined(self.inputs.in_mask):
116124
msk = nb.load(self.inputs.in_mask).get_data()
117125
msk[msk > 0.0] = 1.0
118126
msk[msk < 1.0] = 0.0
127+
else:
128+
msk = np.zeros(shape, dtype=np.uint8)
129+
msk[total_vf > 0.0] = 1
119130

120131
mhdr = hdr.copy()
121132
mhdr.set_data_dtype(np.uint8)
122133
mhdr.set_xyzt_units('mm', 'sec')
123134
nb.Nifti1Image(msk, aff, mhdr).to_filename(
124135
op.abspath(self.inputs.out_mask))
125136

137+
# Initialize stack of args
126138
args = np.hstack((vfs[msk > 0], ffs[msk > 0]))
127139

140+
# Stack directions
128141
for f in self.inputs.in_dirs:
129142
fd = nb.load(f).get_data()
130143
args = np.hstack((args, fd[msk > 0]))
131144

132-
b0 = np.array([b0_im.get_data()[msk > 0]]).T
133-
args = np.hstack((args, b0))
134-
135145
if isdefined(self.inputs.in_bval) and isdefined(self.inputs.in_bvec):
136146
# Load the gradient strengths and directions
137147
bvals = np.loadtxt(self.inputs.in_bval)
@@ -147,7 +157,7 @@ def _run_interface(self, runtime):
147157
np.savetxt(op.abspath(self.inputs.out_bval), gtab.bvals)
148158

149159
snr = self.inputs.snr
150-
args = [tuple(np.hstack((r, gtab, snr))) for r in args]
160+
args = [tuple([nsticks, gtab, snr] + r.tolist()) for r in args]
151161

152162
n_proc = self.inputs.n_proc
153163
if n_proc == 0:
@@ -160,12 +170,22 @@ def _run_interface(self, runtime):
160170

161171
iflogger.info(('Starting simulation of %d voxels, %d diffusion'
162172
' directions.') % (len(args), len(gtab.bvals)))
163-
result = pool.map(_compute_voxel, args)
164-
ndirs = np.shape(result)[1]
173+
174+
result = np.array(pool.map(_compute_voxel, args))
175+
176+
ndirs = len(gtab.bvals)
177+
if np.shape(result)[1] != ndirs:
178+
raise RuntimeError(('Computed directions do not match number'
179+
'of b-values.'))
165180

166181
simulated = np.zeros((shape[0], shape[1], shape[2], ndirs))
167182
simulated[msk > 0] = result
168183

184+
# S0
185+
b0 = b0_im.get_data()
186+
for i in xrange(ndirs):
187+
simulated[..., i] *= b0
188+
169189
simhdr = hdr.copy()
170190
simhdr.set_data_dtype(np.float32)
171191
simhdr.set_xyzt_units('mm', 'sec')
@@ -202,42 +222,41 @@ def _compute_voxel(args):
202222
D_ball = [3000e-6, 960e-6, 680e-6]
203223
sf_evals = [1700e-6, 200e-6, 200e-6]
204224

205-
vfs = [args[0], args[1], args[2]]
206-
ffs = [args[3], args[4], args[5]] # single fiber fractions
207-
sticks = [(args[6], args[7], args[8]),
208-
(args[8], args[10], args[11]),
209-
(args[12], args[13], args[14])]
225+
nf = args[0] # number of fibers
226+
gtab = args[1] # gradient table
227+
snr = args[2]
228+
vfs = args[3:6]
229+
230+
vfs = (np.array(vfs) / np.sum(vfs))
231+
232+
sst = 6 + nf
233+
ffs = args[6:sst] # single fiber fractions
210234

211-
S0 = args[15]
212-
gtab = args[16]
235+
sticks = [tuple(args[sst + i * 3:sst + 3 + i * 3])
236+
for i in range(0, nf)]
213237

214-
nf = len(ffs)
215238
mevals = [sf_evals] * nf
216239
sf_vf = np.sum(ffs)
217-
ffs = ((np.array(ffs) / sf_vf) * 100)
218240

219241
# Simulate sticks
220-
signal, _ = multi_tensor(gtab, np.array(mevals), S0=1,
221-
angles=sticks, fractions=ffs, snr=None)
222-
signal *= sf_vf
242+
if sf_vf > 1.0e-3:
243+
ffs = ((np.array(ffs) / sf_vf) * 100)
244+
signal, _ = multi_tensor(gtab, np.array(mevals), S0=1.0,
245+
angles=sticks, fractions=ffs, snr=None)
246+
else:
247+
signal = np.zeros_like(gtab.bvals, dtype=np.float32)
248+
249+
signal *= vfs[2] * sf_vf
223250

224251
# Simulate balls
225-
r = 1.0 - sf_vf
226-
if r > 1.0e-3:
227-
for vf, d in zip(vfs, D_ball):
228-
f0 = vf * r
229-
signal += f0 * np.exp(-gtab.bvals * d)
230-
231-
snr = None
232-
try:
233-
snr = args[17]
234-
except IndexError:
235-
pass
236-
237-
if snr is not None and snr >= 0:
238-
signal[1:] = add_noise(signal[1:], snr, 1)
239-
240-
return signal * S0
252+
vfs[2] *= (1 - sf_vf)
253+
for f0, d in zip(vfs, D_ball):
254+
signal += f0 * np.exp(-gtab.bvals * d)
255+
256+
if snr > 0:
257+
signal = add_noise(signal, snr, 1)
258+
259+
return signal.tolist()
241260

242261

243262
def _generate_gradients(ndirs=64, values=[1000, 3000], nb0s=1):

0 commit comments

Comments
 (0)