Skip to content

Commit 99d008f

Browse files
committed
enh: focus the PR in only one function, add test
Replaces: #29
1 parent ad814d9 commit 99d008f

File tree

3 files changed

+48
-86
lines changed

3 files changed

+48
-86
lines changed

dmriprep/interfaces/vectors.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,6 @@ class _CheckGradientTableInputSpec(BaseInterfaceInputSpec):
2121
b0_threshold = traits.Float(B0_THRESHOLD, usedefault=True)
2222
bvec_norm_epsilon = traits.Float(BVEC_NORM_EPSILON, usedefault=True)
2323
b_scale = traits.Bool(True, usedefault=True)
24-
image_consistency = traits.Bool(False, usedefault=True)
2524

2625

2726
class _CheckGradientTableOutputSpec(TraitedSpec):
@@ -76,9 +75,8 @@ def _run_interface(self, runtime):
7675
bvals=_undefined(self.inputs, "in_bval"),
7776
rasb_file=rasb_file,
7877
b_scale=self.inputs.b_scale,
79-
image_consistency=self.inputs.image_consistency,
8078
bvec_norm_epsilon=self.inputs.bvec_norm_epsilon,
81-
b0_threshold=self.inputs.b0_threshold
79+
b0_threshold=self.inputs.b0_threshold,
8280
)
8381
pole = table.pole
8482
self._results["pole"] = tuple(pole)

dmriprep/utils/tests/test_vectors.py

Lines changed: 23 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
"""Test vector utilities."""
22
import pytest
33
import numpy as np
4+
import nibabel as nb
45
from dmriprep.utils import vectors as v
56
from collections import namedtuple
67

@@ -86,11 +87,6 @@ def test_corruption(tmpdir, dipy_test_data, monkeypatch):
8687
assert -1.0 <= np.max(np.abs(dgt.gradients[..., :-1])) <= 1.0
8788
assert dgt.normalized is True
8889

89-
# Test image gradient consistency
90-
dgt = v.DiffusionGradientTable(dwi_file=dipy_test_data['dwi_file'],
91-
bvals=bvals, bvecs=bvecs)
92-
assert dgt.gradient_consistency is None
93-
9490
def mock_func(*args, **kwargs):
9591
return "called!"
9692

@@ -105,3 +101,25 @@ def mock_func(*args, **kwargs):
105101
# Miscellaneous tests
106102
with pytest.raises(ValueError):
107103
dgt.to_filename("path", filetype="mrtrix")
104+
105+
106+
def test_b0mask_from_data(tmp_path):
107+
"""Check the estimation of bzeros using the dwi data."""
108+
109+
highb = np.random.normal(100, 50, size=(40, 40, 40, 99))
110+
mask_file = tmp_path / 'mask.nii.gz'
111+
112+
# Test 1: no lowb
113+
dwi_file = tmp_path / 'only_highb.nii.gz'
114+
nb.Nifti1Image(highb.astype(float), np.eye(4), None).to_filename(dwi_file)
115+
nb.Nifti1Image(np.ones((40, 40, 40), dtype=np.uint8),
116+
np.eye(4), None).to_filename(mask_file)
117+
118+
assert v.b0mask_from_data(dwi_file, mask_file).sum() == 0
119+
120+
# Test 1: one lowb
121+
lowb = np.random.normal(400, 50, size=(40, 40, 40, 1))
122+
dwi_file = tmp_path / 'dwi.nii.gz'
123+
nb.Nifti1Image(np.concatenate((lowb, highb), axis=3).astype(float),
124+
np.eye(4), None).to_filename(dwi_file)
125+
assert v.b0mask_from_data(dwi_file, mask_file).sum() == 1

dmriprep/utils/vectors.py

Lines changed: 24 additions & 78 deletions
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,6 @@ def __init__(
103103

104104
if dwi_file is not None:
105105
self.affine = dwi_file
106-
self._dwi_file = dwi_file
107106
if rasb_file is not None:
108107
self.gradients = rasb_file
109108
if self.affine is not None:
@@ -279,18 +278,6 @@ def to_filename(self, filename, filetype="rasb"):
279278
else:
280279
raise ValueError('Unknown filetype "%s"' % filetype)
281280

282-
# @property
283-
# def gradient_consistency(self):
284-
# """
285-
# Check that b gradients and signal variation in dwi image are consistent with one another.
286-
# """
287-
# if (Path(self._dwi_file).exists() is True) and (self._image_consistency is True):
288-
# return image_gradient_consistency_check(str(self._dwi_file), self.bvecs, self.bvals,
289-
# b0_threshold=self._b0_thres)
290-
# else:
291-
# raise FileNotFoundError(
292-
# "DWI file not found, and is required for checking image-gradient consistency.")
293-
294281

295282
def normalize_gradients(
296283
bvecs,
@@ -498,73 +485,32 @@ def bvecs2ras(affine, bvecs, norm=True, bvec_norm_epsilon=0.2):
498485
return rotated_bvecs
499486

500487

501-
def image_gradient_consistency_check(dwi_file, bvecs, bvals, b0_threshold=B0_THRESHOLD):
488+
def rasb_dwi_length_check(dwi_file, rasb_file):
489+
"""Check the number of encoding vectors and number of orientations in the DWI file."""
490+
return nb.load(dwi_file).shape[-1] == len(np.loadtxt(rasb_file, skiprows=1))
491+
492+
493+
def b0mask_from_data(dwi_file, mask_file, z_thres=3.0):
502494
"""
503-
Check that gradient encoding patterns found in the b-values correspond to those
504-
found in the signal intensity variance across volumes of the dwi image.
495+
Evaluate B0 locations relative to mean signal variation.
496+
497+
Standardizes (z-score) the average DWI signal within mask and threshold.
498+
This is a data-driven way of estimating which volumes in the DWI dataset are
499+
really encoding *low-b* acquisitions.
505500
506501
Parameters
507502
----------
508-
dwi_file : str
509-
Optionally provide a file path to the diffusion-weighted image series to which the
510-
bvecs/bvals should correspond.
511-
bvecs : m x n 2d array
512-
B-vectors array.
513-
bvals : 1d array
514-
B-values float array.
515-
b0_threshold : float
516-
Gradient threshold below which volumes and vectors are considered B0's.
503+
dwi_file : :obj:`str`
504+
File path to the diffusion-weighted image series.
505+
mask_file : :obj:`str`
506+
File path to a mask corresponding to the DWI file.
507+
z_thres : :obj:`float`
508+
The z-value to consider a volume as a *low-b* orientation.
509+
517510
"""
518-
import nibabel as nib
519-
from dipy.core.gradients import gradient_table_from_bvals_bvecs
520-
from sklearn.cluster import MeanShift, estimate_bandwidth
521-
522-
# Build gradient table object
523-
gtab = gradient_table_from_bvals_bvecs(bvals, bvecs, b0_threshold=b0_threshold)
524-
525-
# Check that number of image volumes corresponds to the number of gradient encodings.
526-
volume_num = nib.load(dwi_file).shape[-1]
527-
if not len(bvals) == volume_num:
528-
raise Exception("Expected %d total image samples but found %d total gradient encoding values", volume_num,
529-
len(bvals))
530-
# Evaluate B0 locations relative to mean signal variation.
531-
data = np.array(nib.load(dwi_file).dataobj)
532-
signal_means = []
533-
for vol in range(data.shape[-1]):
534-
signal_means.append(np.mean(data[:, :, :, vol]))
535-
signal_b0_indices = np.where(signal_means > np.mean(signal_means)+3*np.std(signal_means))[0]
536-
537-
# Check B0 locations first
538-
if not np.allclose(np.where(gtab.b0s_mask == True)[0], signal_b0_indices):
539-
raise UserWarning('B0 indices in vectors do not correspond to relative high-signal contrast volumes '
540-
'detected in dwi image.')
541-
542-
# Next check number of unique b encodings (i.e. shells) and their indices
543-
X = np.array(signal_means).reshape(-1, 1)
544-
ms = MeanShift(bandwidth=estimate_bandwidth(X, quantile=0.1), bin_seeding=True)
545-
ms.fit(X)
546-
labs, idx = np.unique(ms.labels_, return_index=True)
547-
ix_img = []
548-
i = -1
549-
for val in range(len(ms.labels_)):
550-
if val in np.sort(idx[::-1]):
551-
i = i + 1
552-
ix_img.append(labs[i])
553-
ix_img = np.array(ix_img)
554-
555-
ix_vec = []
556-
i = -1
557-
for val in range(len(bvals)):
558-
if bvals[val] != bvals[val-1]:
559-
i = i + 1
560-
ix_vec.append(i)
561-
ix_vec = np.array(ix_vec)
562-
563-
if len(ms.cluster_centers_) != len(np.unique(bvals)):
564-
raise UserWarning('Number of unique b-values does not correspond to number of unique signal gradient '
565-
'encoding intensities in the dwi image.')
566-
567-
if np.all(np.isclose(ix_img, ix_vec)) is True:
568-
raise UserWarning('Positions of b-value B0\'s and shell(s) do not correspond to signal intensity '
569-
'fluctuation patterns in the dwi image.')
570-
return
511+
data = np.asanyarray(nb.load(dwi_file).dataobj)
512+
mask = np.asanyarray(nb.load(mask_file).dataobj) > 0.5
513+
signal_means = np.median(data[mask, np.newaxis], axis=0)
514+
zscored_means = signal_means - signal_means.mean()
515+
zscored_means /= zscored_means.std()
516+
return zscored_means > z_thres

0 commit comments

Comments
 (0)