Skip to content

Commit 3aec266

Browse files
authored
Merge pull request #89 from arokem/reorient_vectors_arokem
ENH: Update DiffusionGradientTable interface to support vector reorientation
2 parents 23d8db4 + 2f21dd5 commit 3aec266

File tree

2 files changed

+111
-23
lines changed

2 files changed

+111
-23
lines changed

dmriprep/interfaces/vectors.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,9 @@
66
SimpleInterface, BaseInterfaceInputSpec, TraitedSpec,
77
File, traits, isdefined
88
)
9-
from ..utils.vectors import DiffusionGradientTable, B0_THRESHOLD, BVEC_NORM_EPSILON
9+
from ..utils.vectors import (DiffusionGradientTable, B0_THRESHOLD,
10+
BVEC_NORM_EPSILON
11+
)
1012

1113

1214
class _CheckGradientTableInputSpec(BaseInterfaceInputSpec):

dmriprep/utils/vectors.py

Lines changed: 108 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -12,11 +12,29 @@
1212
class DiffusionGradientTable:
1313
"""Data structure for DWI gradients."""
1414

15-
__slots__ = ['_affine', '_gradients', '_b_scale', '_bvecs', '_bvals', '_normalized',
16-
'_b0_thres', '_bvec_norm_epsilon']
17-
18-
def __init__(self, dwi_file=None, bvecs=None, bvals=None, rasb_file=None,
19-
b_scale=True, b0_threshold=B0_THRESHOLD, bvec_norm_epsilon=BVEC_NORM_EPSILON):
15+
__slots__ = [
16+
"_affine",
17+
"_gradients",
18+
"_b_scale",
19+
"_bvecs",
20+
"_bvals",
21+
"_normalized",
22+
"_transforms",
23+
"_b0_thres",
24+
"_bvec_norm_epsilon",
25+
]
26+
27+
def __init__(
28+
self,
29+
dwi_file=None,
30+
bvecs=None,
31+
bvals=None,
32+
rasb_file=None,
33+
b_scale=True,
34+
transforms=None,
35+
b0_threshold=B0_THRESHOLD,
36+
bvec_norm_epsilon=BVEC_NORM_EPSILON,
37+
):
2038
"""
2139
Create a new table of diffusion gradients.
2240
@@ -35,7 +53,29 @@ def __init__(self, dwi_file=None, bvecs=None, bvals=None, rasb_file=None,
3553
b_scale : bool
3654
Whether b-values should be normalized.
3755
56+
Example
57+
-------
58+
>>> os.chdir(tmpdir)
59+
>>> old_rasb = str(data_dir / 'dwi.tsv')
60+
>>> old_rasb_mat = np.loadtxt(str(data_dir / 'dwi.tsv'), skiprows=1)
61+
>>> from dmriprep.utils.vectors import bvecs2ras
62+
>>> check = DiffusionGradientTable(
63+
... dwi_file=str(data_dir / 'dwi.nii.gz'),
64+
... rasb_file=str(data_dir / 'dwi.tsv'))
65+
>>> # Conform to the orientation of the image:
66+
>>> old_rasb_mat[:, :3] = bvecs2ras(check.affine, old_rasb_mat[:, :3])
67+
>>> affines = np.zeros((old_rasb_mat.shape[0], 4, 4))
68+
>>> aff_file_list = []
69+
>>> for ii, aff in enumerate(affines):
70+
... aff_file = f'aff_{ii}.npy'
71+
... np.save(aff_file, aff)
72+
... aff_file_list.append(aff_file)
73+
>>> check._transforms = aff_file_list
74+
>>> out_rasb_mat = check.reorient_rasb()
75+
>>> np.allclose(old_rasb_mat, out_rasb_mat)
76+
True
3877
"""
78+
self._transforms = transforms
3979
self._b_scale = b_scale
4080
self._b0_thres = b0_threshold
4181
self._bvec_norm_epsilon = bvec_norm_epsilon
@@ -87,7 +127,7 @@ def affine(self, value):
87127
dwi_file = nb.load(str(value))
88128
self._affine = dwi_file.affine.copy()
89129
return
90-
if hasattr(value, 'affine'):
130+
if hasattr(value, "affine"):
91131
self._affine = value.affine
92132
self._affine = np.array(value)
93133

@@ -102,20 +142,20 @@ def bvecs(self, value):
102142
if isinstance(value, (str, Path)):
103143
value = np.loadtxt(str(value)).T
104144
else:
105-
value = np.array(value, dtype='float32')
145+
value = np.array(value, dtype="float32")
106146

107147
# Correct any b0's in bvecs misstated as 10's.
108148
value[np.any(abs(value) >= 10, axis=1)] = np.zeros(3)
109149
if self.bvals is not None and value.shape[0] != self.bvals.shape[0]:
110-
raise ValueError('The number of b-vectors and b-values do not match')
150+
raise ValueError("The number of b-vectors and b-values do not match")
111151
self._bvecs = value
112152

113153
@bvals.setter
114154
def bvals(self, value):
115155
if isinstance(value, (str, Path)):
116156
value = np.loadtxt(str(value)).flatten()
117157
if self.bvecs is not None and value.shape[0] != self.bvecs.shape[0]:
118-
raise ValueError('The number of b-vectors and b-values do not match')
158+
raise ValueError("The number of b-vectors and b-values do not match")
119159
self._bvals = np.array(value)
120160

121161
@property
@@ -129,10 +169,12 @@ def normalize(self):
129169
return
130170

131171
self._bvecs, self._bvals = normalize_gradients(
132-
self.bvecs, self.bvals,
172+
self.bvecs,
173+
self.bvals,
133174
b0_threshold=self._b0_thres,
134175
bvec_norm_epsilon=self._bvec_norm_epsilon,
135-
b_scale=self._b_scale)
176+
b_scale=self._b_scale,
177+
)
136178
self._normalized = True
137179

138180
def generate_rasb(self):
@@ -142,14 +184,52 @@ def generate_rasb(self):
142184
_ras = bvecs2ras(self.affine, self.bvecs)
143185
self.gradients = np.hstack((_ras, self.bvals[..., np.newaxis]))
144186

187+
def reorient_rasb(self):
188+
"""Reorient the vectors based o a list of affine transforms."""
189+
from dipy.core.gradients import (gradient_table_from_bvals_bvecs,
190+
reorient_bvecs)
191+
192+
affines = self._transforms.copy()
193+
bvals = self._bvals
194+
bvecs = self._bvecs
195+
196+
# Verify that number of non-B0 volumes corresponds to the number of
197+
# affines. If not, try to fix it, or raise an error:
198+
if len(self._bvals[self._bvals >= self._b0_thres]) != len(affines):
199+
b0_indices = np.where(self._bvals <= self._b0_thres)[0].tolist()
200+
if len(self._bvals[self._bvals >= self._b0_thres]) < len(affines):
201+
for i in sorted(b0_indices, reverse=True):
202+
del affines[i]
203+
if len(self._bvals[self._bvals >= self._b0_thres]) > len(affines):
204+
ras_b_mat = self._gradients.copy()
205+
ras_b_mat = np.delete(ras_b_mat, tuple(b0_indices), axis=0)
206+
bvals = ras_b_mat[:, 3]
207+
bvecs = ras_b_mat[:, 0:3]
208+
if len(self._bvals[self._bvals > self._b0_thres]) != len(affines):
209+
raise ValueError(
210+
"Affine transformations do not correspond to gradients"
211+
)
212+
213+
# Build gradient table object
214+
gt = gradient_table_from_bvals_bvecs(bvals, bvecs,
215+
b0_threshold=self._b0_thres)
216+
217+
# Reorient table
218+
new_gt = reorient_bvecs(gt, [np.load(aff) for aff in affines])
219+
220+
return np.hstack((new_gt.bvecs, new_gt.bvals[..., np.newaxis]))
221+
145222
def generate_vecval(self):
146223
"""Compose a bvec/bval pair in image coordinates."""
147224
if self.bvecs is None or self.bvals is None:
148225
if self.affine is None:
149226
raise TypeError(
150227
"Cannot generate b-vectors & b-values in image coordinates. "
151-
"Please set the corresponding DWI image's affine matrix.")
152-
self._bvecs = bvecs2ras(np.linalg.inv(self.affine), self.gradients[..., :-1])
228+
"Please set the corresponding DWI image's affine matrix."
229+
)
230+
self._bvecs = bvecs2ras(
231+
np.linalg.inv(self.affine), self.gradients[..., :-1]
232+
)
153233
self._bvals = self.gradients[..., -1].flatten()
154234

155235
@property
@@ -161,19 +241,25 @@ def pole(self):
161241
162242
"""
163243
self.generate_rasb()
164-
return calculate_pole(self.gradients[..., :-1], bvec_norm_epsilon=self._bvec_norm_epsilon)
244+
return calculate_pole(
245+
self.gradients[..., :-1], bvec_norm_epsilon=self._bvec_norm_epsilon
246+
)
165247

166-
def to_filename(self, filename, filetype='rasb'):
248+
def to_filename(self, filename, filetype="rasb"):
167249
"""Write files (RASB, bvecs/bvals) to a given path."""
168-
if filetype.lower() == 'rasb':
250+
if filetype.lower() == "rasb":
169251
self.generate_rasb()
170-
np.savetxt(str(filename), self.gradients,
171-
delimiter='\t', header='\t'.join('RASB'),
172-
fmt=['%.8f'] * 3 + ['%g'])
173-
elif filetype.lower() == 'fsl':
252+
np.savetxt(
253+
str(filename),
254+
self.gradients,
255+
delimiter="\t",
256+
header="\t".join("RASB"),
257+
fmt=["%.8f"] * 3 + ["%g"],
258+
)
259+
elif filetype.lower() == "fsl":
174260
self.generate_vecval()
175-
np.savetxt('%s.bvec' % filename, self.bvecs.T, fmt='%.6f')
176-
np.savetxt('%s.bval' % filename, self.bvals, fmt='%.6f')
261+
np.savetxt("%s.bvec" % filename, self.bvecs.T, fmt="%.6f")
262+
np.savetxt("%s.bval" % filename, self.bvals, fmt="%.6f")
177263
else:
178264
raise ValueError('Unknown filetype "%s"' % filetype)
179265

0 commit comments

Comments
 (0)