Skip to content

Commit ee02d74

Browse files
authored
Merge pull request #2 from oesteban/rf/dwi-data-pr
rf: overhaul suggested by #327's title
2 parents c250888 + 16d1b6a commit ee02d74

File tree

2 files changed

+141
-93
lines changed

2 files changed

+141
-93
lines changed

src/nifreeze/data/dmri.py

Lines changed: 113 additions & 82 deletions
Original file line numberDiff line numberDiff line change
@@ -63,18 +63,20 @@
6363
DTI_MIN_ORIENTATIONS = 6
6464
"""Minimum number of diffusion-weighted (nonzero b-value) orientations in a DWI dataset."""
6565

66-
GRADIENT_VOLUME_DIMENSIONALITY_MISMATCH_MISSING_ERROR = "Gradient table shape does not match the number of diffusion volumes: expected {n_volumes} rows, found {n_gradients}."
66+
GRADIENT_VOLUME_DIMENSIONALITY_MISMATCH_MISSING_ERROR = """\
67+
Gradient table shape does not match the number of diffusion volumes: \
68+
expected {n_volumes} rows, found {n_gradients}."""
6769
"""dMRI volume count vs. gradient count mismatch error message."""
6870

69-
GRADIENT_BVAL_BVEC_PRIORITY_WARN_MSG = "Both a gradients table file and b-vec/val files are defined; ignoring b-vec/val files in favor of the gradients_file."
71+
GRADIENT_BVAL_BVEC_PRIORITY_WARN_MSG = """\
72+
Both a gradients table file and b-vec/val files are defined; \
73+
ignoring b-vec/val files in favor of the gradients_file."""
7074
"""dMRI gradient file priority warning message."""
7175

7276
GRADIENT_NDIM_ERROR_MSG = "Gradient table must be a 2D array"
7377
"""dMRI gradient dimensionality error message."""
7478

75-
GRADIENT_DATA_MISSING_ERROR = (
76-
"No gradient data provided. Please specify either a gradients_file or (bvec_file & bval_file)."
77-
)
79+
GRADIENT_DATA_MISSING_ERROR = "No gradient data provided."
7880
"""dMRI missing gradient data error message."""
7981

8082
GRADIENT_EXPECTED_COLUMNS_ERROR_MSG = (
@@ -83,51 +85,115 @@
8385
"""dMRI gradient expected columns error message."""
8486

8587

88+
def format_gradients(
89+
value: npt.ArrayLike | None,
90+
) -> np.ndarray | None:
91+
"""
92+
Validate and orient gradient tables to a consistent shape.
93+
94+
Examples
95+
--------
96+
Passing an already well-formed table returns the data unchanged::
97+
98+
>>> format_gradients(
99+
... [
100+
... [1, 0, 0, 0],
101+
... [0, 1, 0, 1000],
102+
... [0, 0, 1, 2000],
103+
... [0, 0, 0, 0],
104+
... [0, 0, 0, 1000],
105+
... ]
106+
... )
107+
array([[ 1, 0, 0, 0],
108+
[ 0, 1, 0, 1000],
109+
[ 0, 0, 1, 2000],
110+
[ 0, 0, 0, 0],
111+
[ 0, 0, 0, 1000]])
112+
113+
Column-major inputs (four rows and a number of columns other than four)
114+
are automatically transposed to the row-major convention::
115+
116+
>>> format_gradients(
117+
... [[1, 0], [0, 1], [0, 0], [1000, 2000]],
118+
... )
119+
array([[ 1, 0, 0, 1000],
120+
[ 0, 1, 0, 2000]])
121+
122+
Gradient tables must always have two dimensions::
123+
124+
>>> format_gradients([0, 1, 0, 1000])
125+
Traceback (most recent call last):
126+
...
127+
ValueError: Gradient table must be a 2D array
128+
129+
"""
130+
131+
formatted = np.asarray(value)
132+
if formatted.ndim != 2:
133+
raise ValueError(GRADIENT_NDIM_ERROR_MSG)
134+
135+
# Transpose if column-major
136+
return formatted.T if formatted.shape[0] == 4 and formatted.shape[1] != 4 else formatted
137+
138+
139+
def validate_gradients(
140+
instance: DWI,
141+
attribute: attrs.Attribute,
142+
value: npt.NDArray[np.floating],
143+
) -> None:
144+
"""Ensure row-major convention for gradient table."""
145+
if value.shape[1] != 4:
146+
raise ValueError(GRADIENT_EXPECTED_COLUMNS_ERROR_MSG)
147+
148+
86149
@attrs.define(slots=True)
87150
class DWI(BaseDataset[np.ndarray]):
88151
"""Data representation structure for dMRI data."""
89152

153+
gradients: np.ndarray = attrs.field(
154+
default=None,
155+
repr=_data_repr,
156+
eq=attrs.cmp_using(eq=_cmp),
157+
converter=format_gradients,
158+
validator=validate_gradients,
159+
)
160+
"""A 2D numpy array of the gradient table (``N`` orientations x ``C`` components)."""
90161
bzero: np.ndarray | None = attrs.field(
91162
default=None, repr=_data_repr, eq=attrs.cmp_using(eq=_cmp)
92163
)
93-
"""A *b=0* reference map, preferably obtained by some smart averaging."""
94-
gradients: np.ndarray = attrs.field(default=None, repr=_data_repr, eq=attrs.cmp_using(eq=_cmp))
95-
"""A 2D numpy array of the gradient table (``N`` orientations x ``C`` components)."""
164+
"""A *b=0* reference map, computed automatically when low-b frames are present."""
96165
eddy_xfms: list = attrs.field(default=None)
97166
"""List of transforms to correct for estimated eddy current distortions."""
98167

99168
def __attrs_post_init__(self) -> None:
100-
self._normalize_gradients()
101-
102-
def _normalize_gradients(self) -> None:
103169
if self.gradients is None:
104-
return
105-
106-
gradients = np.asarray(self.gradients)
107-
if gradients.ndim != 2:
108-
raise ValueError(GRADIENT_NDIM_ERROR_MSG)
109-
if gradients.shape[1] == 4:
110-
pass
111-
elif gradients.shape[0] == 4:
112-
gradients = gradients.T
113-
else:
114-
raise ValueError(GRADIENT_EXPECTED_COLUMNS_ERROR_MSG)
115-
116-
n_volumes = None
117-
if self.dataobj is not None:
118-
try:
119-
n_volumes = self.dataobj.shape[-1]
120-
except Exception: # pragma: no cover - extremely defensive
121-
n_volumes = None
170+
raise ValueError(GRADIENT_DATA_MISSING_ERROR)
122171

123-
if n_volumes is not None and gradients.shape[0] != n_volumes:
172+
if self.dataobj.shape[-1] != self.gradients.shape[0]:
124173
raise ValueError(
125174
GRADIENT_VOLUME_DIMENSIONALITY_MISMATCH_MISSING_ERROR.format(
126-
n_volumes=n_volumes, n_gradients=gradients.shape[0]
175+
n_volumes=self.dataobj.shape[-1],
176+
n_gradients=self.gradients.shape[0],
127177
)
128178
)
129179

130-
self.gradients = gradients
180+
b0_mask = self.gradients[:, -1] <= DEFAULT_LOWB_THRESHOLD
181+
b0_num = np.sum(b0_mask)
182+
183+
if b0_num > 0 and self.bzero is None:
184+
bzeros = self.dataobj[..., b0_mask]
185+
self.bzero = bzeros if bzeros.ndim == 3 else np.median(bzeros, axis=-1)
186+
187+
if b0_num > 0:
188+
# Remove b0 volumes from dataobj and gradients
189+
self.gradients = self.gradients[~b0_mask, :]
190+
self.dataobj = self.dataobj[..., ~b0_mask]
191+
192+
if self.gradients.shape[0] < DTI_MIN_ORIENTATIONS:
193+
raise ValueError(
194+
f"DWI datasets must have at least {DTI_MIN_ORIENTATIONS} diffusion-weighted "
195+
f"orientations; found {self.dataobj.shape[-1]}."
196+
)
131197

132198
def _getextra(self, idx: int | slice | tuple | np.ndarray) -> tuple[np.ndarray]:
133199
return (self.gradients[idx, ...],)
@@ -339,12 +405,10 @@ def to_nifti(
339405
def from_nii(
340406
filename: Path | str,
341407
brainmask_file: Path | str | None = None,
342-
motion_file: Path | str | None = None,
343408
gradients_file: Path | str | None = None,
344409
bvec_file: Path | str | None = None,
345410
bval_file: Path | str | None = None,
346411
b0_file: Path | str | None = None,
347-
b0_thres: float = DEFAULT_LOWB_THRESHOLD,
348412
) -> DWI:
349413
"""
350414
Load DWI data from NIfTI and construct a DWI object.
@@ -359,8 +423,6 @@ def from_nii(
359423
brainmask_file : :obj:`os.pathlike`, optional
360424
A brainmask NIfTI file. If provided, will be loaded and
361425
stored in the returned dataset.
362-
motion_file : :obj:`os.pathlike`, optional
363-
A file containing head motion affine matrices (linear)
364426
gradients_file : :obj:`os.pathlike`, optional
365427
A text file containing the gradients table, shape (N, C) where the last column
366428
stores the b-values. If provided following the column-major convention (C, N),
@@ -373,9 +435,6 @@ def from_nii(
373435
b0_file : :obj:`os.pathlike`, optional
374436
A NIfTI file containing a b=0 volume (possibly averaged or reference).
375437
If not provided, and the data contains at least one b=0 volume, one will be computed.
376-
b0_thres : float, optional
377-
Threshold for determining which volumes are considered DWI vs. b=0
378-
if you combine them in the same file.
379438
380439
Returns
381440
-------
@@ -390,10 +449,6 @@ def from_nii(
390449
``bvec_file`` + ``bval_file``).
391450
392451
"""
393-
394-
if motion_file:
395-
raise NotImplementedError
396-
397452
filename = Path(filename)
398453

399454
# 1) Load a NIfTI
@@ -405,18 +460,8 @@ def from_nii(
405460
grad = np.loadtxt(gradients_file, dtype="float32")
406461
if bvec_file and bval_file:
407462
warn(GRADIENT_BVAL_BVEC_PRIORITY_WARN_MSG, stacklevel=2)
408-
if grad.ndim != 2:
409-
raise ValueError(GRADIENT_NDIM_ERROR_MSG)
410-
if grad.shape[1] == 4:
411-
pass
412-
elif grad.shape[0] == 4:
413-
grad = grad.T
414-
else:
415-
raise ValueError(GRADIENT_EXPECTED_COLUMNS_ERROR_MSG)
416463
elif bvec_file and bval_file:
417464
bvecs = np.loadtxt(bvec_file, dtype="float32")
418-
if bvecs.ndim == 1:
419-
bvecs = bvecs[np.newaxis, :]
420465
if bvecs.shape[1] != 3 and bvecs.shape[0] == 3:
421466
bvecs = bvecs.T
422467

@@ -427,40 +472,26 @@ def from_nii(
427472
else:
428473
raise RuntimeError(GRADIENT_DATA_MISSING_ERROR)
429474

430-
# 3) Create the DWI instance. We'll filter out volumes where b-value > b0_thres
431-
# as "DW volumes" if the user wants to store only the high-b volumes here
432-
gradmsk = grad[:, -1] > b0_thres
433-
434-
dwi_obj = DWI(
435-
dataobj=fulldata[..., gradmsk],
436-
affine=img.affine,
437-
# We'll assign the filtered gradients below.
438-
)
439-
440-
dwi_obj.gradients = grad[gradmsk, :]
441-
dwi_obj._normalize_gradients()
442-
443-
# 4) b=0 volume (bzero)
444-
# If the user provided a b0_file, load it
475+
# 3) Read b-zero volume if provided
476+
b0_data = None
445477
if b0_file:
446478
b0img = load_api(b0_file, SpatialImage)
447-
b0vol = np.asanyarray(b0img.dataobj)
448-
# We'll assume your DWI class has a bzero: np.ndarray | None attribute
449-
dwi_obj.bzero = b0vol
450-
# Otherwise, if any volumes remain outside gradmsk, compute a median B0:
451-
elif np.any(~gradmsk):
452-
# The b=0 volumes are those that did NOT pass b0_thres
453-
b0_volumes = fulldata[..., ~gradmsk]
454-
# A simple approach is to take the median across that last dimension
455-
# Note that axis=3 is valid only if your data is 4D (x, y, z, volumes).
456-
dwi_obj.bzero = np.median(b0_volumes, axis=3)
457-
458-
# 5) If a brainmask_file was provided, load it
479+
b0_data = np.asanyarray(b0img.dataobj)
480+
481+
# 4) If a brainmask_file was provided, load it
482+
brainmask_data = None
459483
if brainmask_file:
460484
mask_img = load_api(brainmask_file, SpatialImage)
461-
dwi_obj.brainmask = np.asanyarray(mask_img.dataobj, dtype=bool)
485+
brainmask_data = np.asanyarray(mask_img.dataobj, dtype=bool)
462486

463-
return dwi_obj
487+
# 5) Create and return the DWI instance.
488+
return DWI(
489+
dataobj=fulldata,
490+
affine=img.affine,
491+
gradients=grad,
492+
bzero=b0_data,
493+
brainmask=brainmask_data,
494+
)
464495

465496

466497
def find_shelling_scheme(

test/test_data_dmri.py

Lines changed: 28 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -88,11 +88,6 @@ def test_main(datadir):
8888
assert isinstance(load(input_file), DWI)
8989

9090

91-
def test_motion_file_not_implemented():
92-
with pytest.raises(NotImplementedError):
93-
from_nii("dmri.nii.gz", motion_file="motion.x5")
94-
95-
9691
@pytest.mark.random_gtab_data(10, (1000, 2000), 2)
9792
@pytest.mark.random_dwi_data(50, (34, 36, 24), True)
9893
@pytest.mark.parametrize("row_major_gradients", (False, True))
@@ -169,7 +164,7 @@ def test_dwi_instantiation_gradients_ndim_error(
169164
[(1, 0), (2, 0), (2, 1), (0, 1), (0, 2), (1, 2)],
170165
)
171166
def test_gradient_instantiation_dwi_vol_mismatch_error(
172-
tmp_path, setup_random_dwi_data, additional_volume_count, additional_gradient_count
167+
setup_random_dwi_data, additional_volume_count, additional_gradient_count
173168
):
174169
(
175170
dwi_dataobj,
@@ -189,6 +184,25 @@ def test_gradient_instantiation_dwi_vol_mismatch_error(
189184
additional_gradients = np.tile(gradients[-1:, :], (additional_gradient_count, 1))
190185
gradients = np.concatenate((gradients, additional_gradients), axis=0)
191186

187+
# Test with b0s present
188+
n_volumes = dwi_dataobj.shape[-1]
189+
with pytest.raises(
190+
ValueError,
191+
match=GRADIENT_VOLUME_DIMENSIONALITY_MISMATCH_MISSING_ERROR.format(
192+
n_volumes=n_volumes, n_gradients=gradients.shape[0]
193+
),
194+
):
195+
DWI(
196+
dataobj=dwi_dataobj,
197+
affine=affine,
198+
brainmask=brainmask_dataobj,
199+
bzero=b0_dataobj,
200+
gradients=gradients,
201+
)
202+
203+
# Test without b0s present
204+
dwi_dataobj = dwi_dataobj[..., 2:]
205+
gradients = gradients[2:, :]
192206
n_volumes = dwi_dataobj.shape[-1]
193207
with pytest.raises(
194208
ValueError,
@@ -296,7 +310,7 @@ def test_load_gradients_bval_bvec_warn(tmp_path, setup_random_dwi_data):
296310
brainmask_dataobj,
297311
b0_dataobj,
298312
gradients,
299-
b0_thres,
313+
_,
300314
) = setup_random_dwi_data
301315

302316
dwi, _, _ = _dwi_data_to_nifti(
@@ -309,6 +323,9 @@ def test_load_gradients_bval_bvec_warn(tmp_path, setup_random_dwi_data):
309323
dwi_fname = tmp_path / "dwi.nii.gz"
310324
nb.save(dwi, dwi_fname)
311325

326+
b0_fname = tmp_path / "b0.nii.gz"
327+
nb.Nifti1Image(b0_dataobj, np.eye(4), None).to_filename(b0_fname)
328+
312329
grads_fname = tmp_path / "grads.txt"
313330
np.savetxt(grads_fname, gradients, fmt="%.6f")
314331

@@ -326,7 +343,7 @@ def test_load_gradients_bval_bvec_warn(tmp_path, setup_random_dwi_data):
326343
gradients_file=grads_fname,
327344
bvec_file=bvec_fname,
328345
bval_file=bval_fname,
329-
b0_thres=b0_thres,
346+
b0_file=b0_fname,
330347
)
331348

332349

@@ -359,7 +376,7 @@ def test_load_gradients(tmp_path, setup_random_dwi_data, row_major_gradients):
359376
grads_fname = tmp_path / "grads.txt"
360377
np.savetxt(grads_fname, gradients, fmt="%.6f")
361378

362-
dwi = from_nii(dwi_fname, gradients_file=grads_fname, b0_thres=b0_thres)
379+
dwi = from_nii(dwi_fname, gradients_file=grads_fname)
363380
if not row_major_gradients:
364381
gradmask = gradients.T[:, -1] > b0_thres
365382
else:
@@ -419,7 +436,7 @@ def test_load_bvecs_bvals(tmp_path, setup_random_dwi_data, transpose_bvals, tran
419436
np.savetxt(bvec_fname, bvecs, fmt="%.6f")
420437
np.savetxt(bval_fname, bvals, fmt="%.6f")
421438

422-
dwi = from_nii(dwi_fname, bvec_file=bvec_fname, bval_file=bval_fname, b0_thres=b0_thres)
439+
dwi = from_nii(dwi_fname, bvec_file=bvec_fname, bval_file=bval_fname)
423440
gradmask = gradients[:, -1] > b0_thres
424441

425442
expected_nonzero_grads = gradients[gradmask]
@@ -454,6 +471,7 @@ def test_load_gradients_missing(tmp_path, setup_random_dwi_data):
454471
from_nii(dwi_fname)
455472

456473

474+
@pytest.mark.skip(reason="to_nifti takes absurdly long")
457475
@pytest.mark.parametrize("insert_b0", (False, True))
458476
@pytest.mark.parametrize("rotate_bvecs", (False, True))
459477
def test_load(datadir, tmp_path, insert_b0, rotate_bvecs): # noqa: C901
@@ -608,7 +626,6 @@ def test_equality_operator(tmp_path, setup_random_dwi_data):
608626
gradients_file=gradients_fname,
609627
b0_file=b0_fname,
610628
brainmask_file=brainmask_fname,
611-
b0_thres=b0_thres,
612629
)
613630
hdf5_filename = tmp_path / "test_dwi.h5"
614631
dwi_obj.to_filename(hdf5_filename)

0 commit comments

Comments
 (0)