Skip to content

Commit c73fb80

Browse files
committed
add automated generation of gradient table, add sticks&balls
1 parent 29bb49c commit c73fb80

File tree

1 file changed

+134
-42
lines changed

1 file changed

+134
-42
lines changed

nipype/interfaces/dipy/simulate.py

Lines changed: 134 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
"""
88

99
from nipype.interfaces.base import (
10-
TraitedSpec, BaseInterface, BaseInterfaceInputSpec, File,
10+
traits, TraitedSpec, BaseInterface, BaseInterfaceInputSpec, File,
1111
InputMultiPath, isdefined)
1212
from nipype.utils.filemanip import split_filename
1313
import os.path as op
@@ -16,6 +16,9 @@
1616
from nipype.utils.misc import package_check
1717
import warnings
1818

19+
from multiprocessing import (Process, Pool, cpu_count, pool,
20+
Manager, TimeoutError)
21+
1922
from ... import logging
2023
iflogger = logging.getLogger('interface')
2124

@@ -28,7 +31,7 @@
2831
import numpy as np
2932
from dipy.sims.voxel import (multi_tensor,
3033
all_tensor_evecs)
31-
from dipy.core.gradients import GradientTable
34+
from dipy.core.gradients import gradient_table
3235

3336

3437
class SimulateMultiTensorInputSpec(BaseInterfaceInputSpec):
@@ -40,19 +43,30 @@ class SimulateMultiTensorInputSpec(BaseInterfaceInputSpec):
4043
desc='volume fraction map')
4144
in_mask = File(exists=True, desc='mask to simulate data')
4245

46+
n_proc = traits.Int(0, usedefault=True, desc='number of processes')
4347
baseline = File(exists=True, mandatory=True, desc='baseline T2 signal')
4448
gradients = File(exists=True, desc='gradients file')
45-
bvec = File(exists=True, mandatory=True, desc='bvecs file')
46-
bval = File(exists=True, mandatory=True, desc='bvals file')
49+
bvec = File(exists=True, desc='bvecs file')
50+
bval = File(exists=True, desc='bvals file')
51+
num_dirs = traits.Int(32, usedefault=True,
52+
desc=('number of gradient directions (when table '
53+
'is automatically generated)'))
54+
bvalues = traits.List(traits.Int, value=[1000, 3000], usedefault=True,
55+
desc=('list of b-values (when table '
56+
'is automatically generated)'))
4757
out_file = File('sim_dwi.nii.gz', usedefault=True,
4858
desc='output file with fractions to be simluated')
4959
out_mask = File('sim_msk.nii.gz', usedefault=True,
5060
desc='file with the mask simulated')
61+
out_bvec = File('bvec.sim', usedefault=True, desc='simulated b vectors')
62+
out_bval = File('bval.sim', usedefault=True, desc='simulated b values')
5163

5264

5365
class SimulateMultiTensorOutputSpec(TraitedSpec):
5466
out_file = File(exists=True, desc='simulated DWIs')
5567
out_mask = File(exists=True, desc='mask file')
68+
out_bvec = File(exists=True, desc='simulated b vectors')
69+
out_bval = File(exists=True, desc='simulated b values')
5670

5771

5872
class SimulateMultiTensor(BaseInterface):
@@ -82,18 +96,13 @@ def _run_interface(self, runtime):
8296
hdr = b0_im.get_header()
8397
shape = b0_im.get_shape()
8498
aff = b0_im.get_affine()
85-
b0 = b0_im.get_data().reshape(-1)
8699

87100
ffsim = nb.concat_images([nb.load(f) for f in self.inputs.in_frac])
88101
ffs = np.squeeze(ffsim.get_data()) # fiber fractions
89102

90103
vfsim = nb.concat_images([nb.load(f) for f in self.inputs.in_vfms])
91104
vfs = np.squeeze(vfsim.get_data()) # volume fractions
92105

93-
# Load structural files
94-
thetas = []
95-
phis = []
96-
97106
total_ff = np.sum(ffs, axis=3)
98107
total_vf = np.sum(vfs, axis=3)
99108

@@ -111,50 +120,133 @@ def _run_interface(self, runtime):
111120
nb.Nifti1Image(msk, aff, mhdr).to_filename(
112121
op.abspath(self.inputs.out_mask))
113122

123+
args = np.hstack((vfs[msk > 0], ffs[msk > 0]))
124+
114125
for f in self.inputs.in_dirs:
115126
fd = nb.load(f).get_data()
116-
x = fd[msk > 0][..., 0]
117-
y = fd[msk > 0][..., 1]
118-
z = fd[msk > 0][..., 2]
119-
th = np.arccos(z / np.sqrt(x ** 2 + y ** 2 + z ** 2))
120-
ph = np.arctan2(y, x)
121-
thetas.append(th)
122-
phis.append(ph)
123-
124-
# Load the gradient strengths and directions
125-
bvals = np.loadtxt(self.inputs.bval)
126-
gradients = np.loadtxt(self.inputs.bvec).T
127-
128-
# Place in Dipy's preferred format
129-
gtab = GradientTable(gradients)
130-
gtab.bvals = bvals
127+
args = np.hstack((args, fd[msk > 0]))
128+
129+
b0 = np.array([b0_im.get_data()[msk > 0]]).T
130+
args = np.hstack((args, b0))
131+
132+
if isdefined(self.inputs.bval) and isdefined(self.inputs.bvec):
133+
# Load the gradient strengths and directions
134+
bvals = np.loadtxt(self.inputs.bval)
135+
bvecs = np.loadtxt(self.inputs.bvec).T
136+
137+
# Place in Dipy's preferred format
138+
gtab = gradient_table(bvals, bvecs)
139+
else:
140+
gtab = _generate_gradients(self.inputs.num_dirs,
141+
self.inputs.bvalues)
142+
143+
np.savetxt(op.abspath(self.inputs.out_bvec), gtab.bvecs.T)
144+
np.savetxt(op.abspath(self.inputs.out_bval), gtab.bvals.T)
145+
146+
args = [tuple(np.hstack((r, gtab))) for r in args]
147+
148+
n_proc = self.inputs.n_proc
149+
if n_proc == 0:
150+
n_proc = cpu_count()
151+
152+
try:
153+
pool = Pool(processes=n_proc, maxtasksperchild=50)
154+
except TypeError:
155+
pool = Pool(processes=n_proc)
156+
157+
iflogger.info('Starting simulation of %d voxels' % len(args))
158+
result = pool.map(_compute_voxel, args)
159+
ndirs = np.shape(result)[1]
160+
161+
simulated = np.zeros((shape[0], shape[1], shape[2], ndirs))
162+
simulated[msk > 0] = result
163+
164+
simhdr = hdr.copy()
165+
simhdr.set_data_dtype(np.float32)
166+
simhdr.set_xyzt_units('mm', 'sec')
167+
nb.Nifti1Image(simulated.astype(np.float32), aff,
168+
simhdr).to_filename(op.abspath(self.inputs.out_file))
131169

132170
return runtime
133171

134172
def _list_outputs(self):
135173
outputs = self._outputs().get()
136174
outputs['out_file'] = op.abspath(self.inputs.out_file)
137175
outputs['out_mask'] = op.abspath(self.inputs.out_mask)
176+
outputs['out_bvec'] = op.abspath(self.inputs.out_bvec)
177+
outputs['out_bval'] = op.abspath(self.inputs.out_bval)
178+
138179
return outputs
139180

140181

141-
def _compute_voxel(vfs, ffs, ths, phs, S0, gtab, snr=None,
142-
csf_evals=[0.0015, 0.0015, 0.0015],
143-
gm_evals=[0.0007, 0.0007, 0.0007],
144-
wm_evals=[0.0015, 0.0003, 0.0003]):
182+
def _compute_voxel(args):
183+
D_ball = [3000e-6, 960e-6, 680e-6]
184+
sf_evals = [1700e-6, 200e-6, 200e-6]
185+
186+
vfs = [args[0], args[1], args[2]]
187+
ffs = [args[3], args[4], args[5]] # single fiber fractions
188+
sticks = [(args[6], args[7], args[8]),
189+
(args[8], args[10], args[11]),
190+
(args[12], args[13], args[14])]
191+
192+
S0 = args[15]
193+
gtab = args[16]
145194

146195
nf = len(ffs)
147-
total_ff = np.sum(ffs)
148-
149-
gm_vf = vfs[1] * (1 - total_ff) / (vfs[0] + vfs[1])
150-
ffs.insert(0, gm_vf)
151-
csf_vf = vfs[0] * (1 - total_ff) / (vfs[0] + vfs[1])
152-
ffs.insert(0, csf_vf)
153-
angles = [(0, 0), (0, 0)] # angles of gm and csf
154-
angles += [(th, ph) for ph, th in zip(ths, phs)]
155-
156-
mevals = np.array([csf_evals, gm_evals] + [wm_evals] * nf)
157-
ffs = np.array(ffs) * 100
158-
signal, sticks = multi_tensor(gtab, mevals, S0=S0, angles=angles,
159-
fractions=ffs, snr=snr)
160-
return signal, sticks
196+
mevals = [sf_evals] * nf
197+
sf_vf = np.sum(ffs)
198+
ffs = ((np.array(ffs) / sf_vf) * 100)
199+
200+
# Simulate sticks
201+
signal, _ = multi_tensor(gtab, np.array(mevals), S0=1,
202+
angles=sticks, fractions=ffs, snr=None)
203+
signal *= sf_vf
204+
205+
# Simulate balls
206+
r = 1.0 - sf_vf
207+
if r > 1.0e-3:
208+
for vf, d in zip(vfs, D_ball):
209+
f0 = vf * r
210+
signal += f0 * np.exp(-gtab.bvals * d)
211+
212+
snr = None
213+
try:
214+
snr = args[17]
215+
except IndexError:
216+
pass
217+
218+
return signal * S0
219+
220+
221+
def _generate_gradients(ndirs=64, values=[1000, 3000], nb0s=1):
222+
"""
223+
Automatically generate a `gradient table
224+
<http://nipy.org/dipy/examples_built/gradients_spheres.html#example-gradients-spheres>`_
225+
226+
"""
227+
import numpy as np
228+
from dipy.core.sphere import (disperse_charges, Sphere, HemiSphere)
229+
from dipy.core.gradients import gradient_table
230+
231+
theta = np.pi * np.random.rand(ndirs)
232+
phi = 2 * np.pi * np.random.rand(ndirs)
233+
hsph_initial = HemiSphere(theta=theta, phi=phi)
234+
hsph_updated, potential = disperse_charges(hsph_initial, 5000)
235+
236+
values = np.atleast_1d(values).tolist()
237+
vertices = hsph_updated.vertices
238+
bvecs = vertices.copy()
239+
bvals = np.ones(vertices.shape[0]) * values[0]
240+
241+
for v in values[1:]:
242+
bvecs = np.vstack((bvecs, vertices))
243+
bvals = np.hstack((bvals, v * np.ones(vertices.shape[0])))
244+
245+
for i in xrange(0, nb0s):
246+
bvals = bvals.tolist()
247+
bvals.insert(0, 0)
248+
249+
bvecs = bvecs.tolist()
250+
bvecs.insert(0, np.zeros(3))
251+
252+
return gradient_table(bvals, bvecs)

0 commit comments

Comments
 (0)