Skip to content

Commit a6338da

Browse files
committed
delegate isotropic diffusion simulation to dipy
1 parent 92c7c72 commit a6338da

File tree

1 file changed

+33
-32
lines changed

1 file changed

+33
-32
lines changed

nipype/interfaces/dipy/simulate.py

Lines changed: 33 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -45,11 +45,11 @@ class SimulateMultiTensorInputSpec(BaseInterfaceInputSpec):
4545
in_mask = File(exists=True, desc='mask to simulate data')
4646

4747
diff_iso = traits.List(
48-
traits.Float, default=[3000e-6, 960e-6, 680e-6], usedefault=True,
48+
[3000e-6, 960e-6, 680e-6], traits.Float, usedefault=True,
4949
desc='Diffusivity of isotropic compartments')
5050
diff_sf = traits.Tuple(
51-
traits.Float, traits.Float, traits.Float,
52-
default=(1700e-6, 200e-6, 200e-6), usedefault=True,
51+
(1700e-6, 200e-6, 200e-6),
52+
traits.Float, traits.Float, traits.Float, usedefault=True,
5353
desc='Single fiber tensor')
5454

5555
n_proc = traits.Int(0, usedefault=True, desc='number of processes')
@@ -128,14 +128,35 @@ def _run_interface(self, runtime):
128128
raise RuntimeError(('Number of sticks and their volume fractions'
129129
' must match.'))
130130

131-
ffsim = nb.concat_images([nb.load(f) for f in self.inputs.in_frac])
132-
ffs = np.squeeze(ffsim.get_data()) # fiber fractions
133-
ffs[ffs > 1.0] = 1.0
134-
ffs[ffs < 0.0] = 0.0
131+
# Volume fractions of isotropic compartments
132+
nballs = len(self.inputs.in_vfms)
133+
vfs = np.squeeze(nb.concat_images([nb.load(f) for f in self.inputs.in_vfms]).get_data())
134+
if nballs == 1:
135+
vfs = vfs[..., np.newaxis]
136+
total_vf = np.sum(vfs, axis=3)
135137

138+
# Generate a mask
139+
if isdefined(self.inputs.in_mask):
140+
msk = nb.load(self.inputs.in_mask).get_data()
141+
msk[msk > 0.0] = 1.0
142+
msk[msk < 1.0] = 0.0
143+
else:
144+
msk = np.zeros(shape)
145+
msk[total_vf > 0.0] = 1.0
146+
147+
msk = np.clip(msk, 0.0, 1.0)
148+
nvox = len(msk[msk > 0])
149+
150+
# Fiber fractions
151+
ffsim = nb.concat_images([nb.load(f) for f in self.inputs.in_frac])
152+
ffs = np.nan_to_num(np.squeeze(ffsim.get_data())) # fiber fractions
153+
ffs = np.clip(ffs, 0., 1.)
136154
if nsticks == 1:
137155
ffs = ffs[..., np.newaxis]
138156

157+
for i in range(nsticks):
158+
ffs[..., i] *= msk
159+
139160
total_ff = np.sum(ffs, axis=3)
140161

141162
# Fix incongruencies in fiber fractions
@@ -147,33 +168,14 @@ def _run_interface(self, runtime):
147168
ffs[ffs < 0.0] = 0.0
148169
total_ff = np.sum(ffs, axis=3)
149170

150-
# Volume fractions of isotropic compartiments
151-
nballs = len(self.inputs.in_vfms)
152-
vfs = np.squeeze(nb.concat_images([nb.load(f) for f in self.inputs.in_vfms]).get_data())
153-
if nsticks == 1:
154-
vfs = vfs[..., np.newaxis]
155-
156-
157171
for i in range(vfs.shape[-1]):
158172
vfs[..., i] -= total_ff
159-
vfs[vfs < 0.0] = 0
173+
vfs = np.clip(vfs, 0., 1.)
160174

161175
fractions = np.concatenate((ffs, vfs), axis=3)
162-
total_vf = np.sum(fractions, axis=3)
163176
nb.Nifti1Image(fractions, aff, None).to_filename('fractions.nii.gz')
164177
nb.Nifti1Image(total_vf, aff, None).to_filename('total_vf.nii.gz')
165178

166-
# Generate a mask
167-
if isdefined(self.inputs.in_mask):
168-
msk = nb.load(self.inputs.in_mask).get_data()
169-
msk[msk > 0.0] = 1.0
170-
msk[msk < 1.0] = 0.0
171-
else:
172-
msk = np.zeros(shape, dtype=np.uint8)
173-
msk[total_vf > 0.0] = 1
174-
175-
nvox = len(mask[mask > 0])
176-
177179
mhdr = hdr.copy()
178180
mhdr.set_data_dtype(np.uint8)
179181
mhdr.set_xyzt_units('mm', 'sec')
@@ -194,19 +196,18 @@ def _run_interface(self, runtime):
194196

195197

196198
sf_evals = list(self.inputs.diff_sf)
197-
ba_evals = self.inputs.diff_iso
199+
ba_evals = list(self.inputs.diff_iso)
198200

201+
mevals = [sf_evals] * nsticks + [[ba_evals[d]]*3 for d in range(nballs)]
199202
args = []
200203
for i in range(nvox):
201204
args.append(
202205
{'fractions': fracs[i, ...].tolist(),
203-
'sticks': [(1.0, 0.0, 0.0)] * nballs + dirs[i, ...].tolist(),
206+
'sticks': [tuple(dirs[i, j:j+3]) for j in range(nsticks)] + [(1.0, 0.0, 0.0)] * nballs,
204207
'gradients': gtab,
205-
'mevals': [[ba_evals[d]*3] for d in range(nballs)] + [sf_evals] * nsticks
208+
'mevals': mevals
206209
})
207210

208-
print args[:5]
209-
210211
n_proc = self.inputs.n_proc
211212
if n_proc == 0:
212213
n_proc = cpu_count()

0 commit comments

Comments
 (0)