Skip to content

Commit ad814d9

Browse files
dPysoesteban
authored andcommitted
enh: add a data-driven *b = 0* estimator
1 parent b9c20c4 commit ad814d9

File tree

3 files changed

+93
-1
lines changed

3 files changed

+93
-1
lines changed

dmriprep/interfaces/vectors.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ class _CheckGradientTableInputSpec(BaseInterfaceInputSpec):
2121
b0_threshold = traits.Float(B0_THRESHOLD, usedefault=True)
2222
bvec_norm_epsilon = traits.Float(BVEC_NORM_EPSILON, usedefault=True)
2323
b_scale = traits.Bool(True, usedefault=True)
24+
image_consistency = traits.Bool(False, usedefault=True)
2425

2526

2627
class _CheckGradientTableOutputSpec(TraitedSpec):
@@ -75,8 +76,9 @@ def _run_interface(self, runtime):
7576
bvals=_undefined(self.inputs, "in_bval"),
7677
rasb_file=rasb_file,
7778
b_scale=self.inputs.b_scale,
79+
image_consistency=self.inputs.image_consistency,
7880
bvec_norm_epsilon=self.inputs.bvec_norm_epsilon,
79-
b0_threshold=self.inputs.b0_threshold,
81+
b0_threshold=self.inputs.b0_threshold
8082
)
8183
pole = table.pole
8284
self._results["pole"] = tuple(pole)

dmriprep/utils/tests/test_vectors.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,11 @@ def test_corruption(tmpdir, dipy_test_data, monkeypatch):
8686
assert -1.0 <= np.max(np.abs(dgt.gradients[..., :-1])) <= 1.0
8787
assert dgt.normalized is True
8888

89+
# Test image gradient consistency
90+
dgt = v.DiffusionGradientTable(dwi_file=dipy_test_data['dwi_file'],
91+
bvals=bvals, bvecs=bvecs)
92+
assert dgt.gradient_consistency is None
93+
8994
def mock_func(*args, **kwargs):
9095
return "called!"
9196

dmriprep/utils/vectors.py

Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,7 @@ def __init__(
103103

104104
if dwi_file is not None:
105105
self.affine = dwi_file
106+
self._dwi_file = dwi_file
106107
if rasb_file is not None:
107108
self.gradients = rasb_file
108109
if self.affine is not None:
@@ -278,6 +279,18 @@ def to_filename(self, filename, filetype="rasb"):
278279
else:
279280
raise ValueError('Unknown filetype "%s"' % filetype)
280281

282+
# @property
283+
# def gradient_consistency(self):
284+
# """
285+
# Check that b gradients and signal variation in dwi image are consistent with one another.
286+
# """
287+
# if (Path(self._dwi_file).exists() is True) and (self._image_consistency is True):
288+
# return image_gradient_consistency_check(str(self._dwi_file), self.bvecs, self.bvals,
289+
# b0_threshold=self._b0_thres)
290+
# else:
291+
# raise FileNotFoundError(
292+
# "DWI file not found, and is required for checking image-gradient consistency.")
293+
281294

282295
def normalize_gradients(
283296
bvecs,
@@ -483,3 +496,75 @@ def bvecs2ras(affine, bvecs, norm=True, bvec_norm_epsilon=0.2):
483496
rotated_bvecs[~b0s] /= norms_bvecs[~b0s, np.newaxis]
484497
rotated_bvecs[b0s] = np.zeros(3)
485498
return rotated_bvecs
499+
500+
501+
def image_gradient_consistency_check(dwi_file, bvecs, bvals, b0_threshold=B0_THRESHOLD):
502+
"""
503+
Check that gradient encoding patterns found in the b-values correspond to those
504+
found in the signal intensity variance across volumes of the dwi image.
505+
506+
Parameters
507+
----------
508+
dwi_file : str
509+
Optionally provide a file path to the diffusion-weighted image series to which the
510+
bvecs/bvals should correspond.
511+
bvecs : m x n 2d array
512+
B-vectors array.
513+
bvals : 1d array
514+
B-values float array.
515+
b0_threshold : float
516+
Gradient threshold below which volumes and vectors are considered B0's.
517+
"""
518+
import nibabel as nib
519+
from dipy.core.gradients import gradient_table_from_bvals_bvecs
520+
from sklearn.cluster import MeanShift, estimate_bandwidth
521+
522+
# Build gradient table object
523+
gtab = gradient_table_from_bvals_bvecs(bvals, bvecs, b0_threshold=b0_threshold)
524+
525+
# Check that number of image volumes corresponds to the number of gradient encodings.
526+
volume_num = nib.load(dwi_file).shape[-1]
527+
if not len(bvals) == volume_num:
528+
raise Exception("Expected %d total image samples but found %d total gradient encoding values", volume_num,
529+
len(bvals))
530+
# Evaluate B0 locations relative to mean signal variation.
531+
data = np.array(nib.load(dwi_file).dataobj)
532+
signal_means = []
533+
for vol in range(data.shape[-1]):
534+
signal_means.append(np.mean(data[:, :, :, vol]))
535+
signal_b0_indices = np.where(signal_means > np.mean(signal_means)+3*np.std(signal_means))[0]
536+
537+
# Check B0 locations first
538+
if not np.allclose(np.where(gtab.b0s_mask == True)[0], signal_b0_indices):
539+
raise UserWarning('B0 indices in vectors do not correspond to relative high-signal contrast volumes '
540+
'detected in dwi image.')
541+
542+
# Next check number of unique b encodings (i.e. shells) and their indices
543+
X = np.array(signal_means).reshape(-1, 1)
544+
ms = MeanShift(bandwidth=estimate_bandwidth(X, quantile=0.1), bin_seeding=True)
545+
ms.fit(X)
546+
labs, idx = np.unique(ms.labels_, return_index=True)
547+
ix_img = []
548+
i = -1
549+
for val in range(len(ms.labels_)):
550+
if val in np.sort(idx[::-1]):
551+
i = i + 1
552+
ix_img.append(labs[i])
553+
ix_img = np.array(ix_img)
554+
555+
ix_vec = []
556+
i = -1
557+
for val in range(len(bvals)):
558+
if bvals[val] != bvals[val-1]:
559+
i = i + 1
560+
ix_vec.append(i)
561+
ix_vec = np.array(ix_vec)
562+
563+
if len(ms.cluster_centers_) != len(np.unique(bvals)):
564+
raise UserWarning('Number of unique b-values does not correspond to number of unique signal gradient '
565+
'encoding intensities in the dwi image.')
566+
567+
if np.all(np.isclose(ix_img, ix_vec)) is True:
568+
raise UserWarning('Positions of b-value B0\'s and shell(s) do not correspond to signal intensity '
569+
'fluctuation patterns in the dwi image.')
570+
return

0 commit comments

Comments
 (0)