Skip to content

Commit 98bb3b5

Browse files
oestebanjhlegarreta
authored andcommitted
rf: overhaul suggested by #327's title
This patch - moves the responsibility of formatting and validating the gradients into the attrs' infrastructure; - removes complexity from ``from_nifti()`` that was completely unnecessary (e.g., motion file can be set after creating the DWI object). - updates tests - skips one test that is unreasonably slow, we need to look into what's going on with ``to_nifti()``.
1 parent 0b6faa0 commit 98bb3b5

File tree

6 files changed

+152
-122
lines changed

6 files changed

+152
-122
lines changed

src/nifreeze/data/__init__.py

Lines changed: 2 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,6 @@
3232
def load(
3333
filename: Path | str,
3434
brainmask_file: Path | str | None = None,
35-
motion_file: Path | str | None = None,
3635
**kwargs,
3736
) -> BaseDataset | DWI | PET:
3837
"""
@@ -45,8 +44,6 @@ def load(
4544
brainmask_file : :obj:`os.pathlike`, optional
4645
A brainmask NIfTI file. If provided, will be loaded and
4746
stored in the returned dataset.
48-
motion_file : :obj:`os.pathlike`
49-
A file containing head motion affine matrices (linear).
5047
5148
Returns
5249
-------
@@ -67,9 +64,6 @@ def load(
6764

6865
from nifreeze.utils.ndimage import load_api
6966

70-
if motion_file:
71-
raise NotImplementedError
72-
7367
filename = Path(filename)
7468
if filename.name.endswith(NFDH5_EXT):
7569
for dataclass in (BaseDataset, PET, DWI):
@@ -81,15 +75,11 @@ def load(
8175
if "gradients_file" in kwargs or "bvec_file" in kwargs:
8276
from nifreeze.data.dmri import from_nii as dmri_from_nii
8377

84-
return dmri_from_nii(
85-
filename, brainmask_file=brainmask_file, motion_file=motion_file, **kwargs
86-
)
78+
return dmri_from_nii(filename, brainmask_file=brainmask_file, **kwargs)
8779
elif "frame_time" in kwargs or "frame_duration" in kwargs:
8880
from nifreeze.data.pet import from_nii as pet_from_nii
8981

90-
return pet_from_nii(
91-
filename, brainmask_file=brainmask_file, motion_file=motion_file, **kwargs
92-
)
82+
return pet_from_nii(filename, brainmask_file=brainmask_file, **kwargs)
9383

9484
img = load_api(filename, SpatialImage)
9585
retval: BaseDataset = BaseDataset(dataobj=np.asanyarray(img.dataobj), affine=img.affine)

src/nifreeze/data/dmri.py

Lines changed: 113 additions & 82 deletions
Original file line numberDiff line numberDiff line change
@@ -63,18 +63,20 @@
6363
DTI_MIN_ORIENTATIONS = 6
6464
"""Minimum number of nonzero b-values in a DWI dataset."""
6565

66-
GRADIENT_VOLUME_DIMENSIONALITY_MISMATCH_MISSING_ERROR = "Gradient table shape does not match the number of diffusion volumes: expected {n_volumes} rows, found {n_gradients}."
66+
GRADIENT_VOLUME_DIMENSIONALITY_MISMATCH_MISSING_ERROR = """\
67+
Gradient table shape does not match the number of diffusion volumes: \
68+
expected {n_volumes} rows, found {n_gradients}."""
6769
"""dMRI volume count vs. gradient count mismatch error message."""
6870

69-
GRADIENT_BVAL_BVEC_PRIORITY_WARN_MSG = "Both a gradients table file and b-vec/val files are defined; ignoring b-vec/val files in favor of the gradients_file."
71+
GRADIENT_BVAL_BVEC_PRIORITY_WARN_MSG = """\
72+
Both a gradients table file and b-vec/val files are defined; \
73+
ignoring b-vec/val files in favor of the gradients_file."""
7074
""""dMRI gradient file priority warning message."""
7175

7276
GRADIENT_NDIM_ERROR_MSG = "Gradient table must be a 2D array"
7377
"""dMRI gradient dimensionality error message."""
7478

75-
GRADIENT_DATA_MISSING_ERROR = (
76-
"No gradient data provided. Please specify either a gradients_file or (bvec_file & bval_file)."
77-
)
79+
GRADIENT_DATA_MISSING_ERROR = "No gradient data provided."
7880
"""dMRI missing gradient data error message."""
7981

8082
GRADIENT_EXPECTED_COLUMNS_ERROR_MSG = (
@@ -83,51 +85,115 @@
8385
"""dMRI gradient expected columns error message."""
8486

8587

88+
def format_gradients(
89+
value: npt.ArrayLike | None,
90+
) -> np.ndarray | None:
91+
"""
92+
Validate and orient gradient tables to a consistent shape.
93+
94+
Examples
95+
--------
96+
Passing an already well-formed table returns the data unchanged::
97+
98+
>>> format_gradients(
99+
... [
100+
... [1, 0, 0, 0],
101+
... [0, 1, 0, 1000],
102+
... [0, 0, 1, 2000],
103+
... [0, 0, 0, 0],
104+
... [0, 0, 0, 1000],
105+
... ]
106+
... )
107+
array([[ 1, 0, 0, 0],
108+
[ 0, 1, 0, 1000],
109+
[ 0, 0, 1, 2000],
110+
[ 0, 0, 0, 0],
111+
[ 0, 0, 0, 1000]])
112+
113+
Column-major inputs are automatically transposed when an expected
114+
number of diffusion volumes is provided::
115+
116+
>>> format_gradients(
117+
... [[1, 0], [0, 1], [0, 0], [1000, 2000]],
118+
... )
119+
array([[ 1, 0, 0, 1000],
120+
[ 0, 1, 0, 2000]])
121+
122+
Gradient tables must always have two dimensions::
123+
124+
>>> format_gradients([0, 1, 0, 1000])
125+
Traceback (most recent call last):
126+
...
127+
ValueError: Gradient table must be a 2D array
128+
129+
"""
130+
131+
formatted = np.asarray(value)
132+
if formatted.ndim != 2:
133+
raise ValueError(GRADIENT_NDIM_ERROR_MSG)
134+
135+
# Transpose if column-major
136+
return formatted.T if formatted.shape[0] == 4 and formatted.shape[1] != 4 else formatted
137+
138+
139+
def validate_gradients(
140+
instance: DWI,
141+
attribute: attrs.Attribute,
142+
value: npt.NDArray[np.floating],
143+
) -> None:
144+
"""Ensure row-major convention for gradient table."""
145+
if value.shape[1] != 4:
146+
raise ValueError(GRADIENT_EXPECTED_COLUMNS_ERROR_MSG)
147+
148+
86149
@attrs.define(slots=True)
87150
class DWI(BaseDataset[np.ndarray]):
88151
"""Data representation structure for dMRI data."""
89152

153+
gradients: np.ndarray = attrs.field(
154+
default=None,
155+
repr=_data_repr,
156+
eq=attrs.cmp_using(eq=_cmp),
157+
converter=format_gradients,
158+
validator=validate_gradients,
159+
)
160+
"""A 2D numpy array of the gradient table (``N`` orientations x ``C`` components)."""
90161
bzero: np.ndarray | None = attrs.field(
91162
default=None, repr=_data_repr, eq=attrs.cmp_using(eq=_cmp)
92163
)
93-
"""A *b=0* reference map, preferably obtained by some smart averaging."""
94-
gradients: np.ndarray = attrs.field(default=None, repr=_data_repr, eq=attrs.cmp_using(eq=_cmp))
95-
"""A 2D numpy array of the gradient table (``N`` orientations x ``C`` components)."""
164+
"""A *b=0* reference map, computed automatically when low-b frames are present."""
96165
eddy_xfms: list = attrs.field(default=None)
97166
"""List of transforms to correct for estimated eddy current distortions."""
98167

99168
def __attrs_post_init__(self) -> None:
100-
self._normalize_gradients()
101-
102-
def _normalize_gradients(self) -> None:
103169
if self.gradients is None:
104-
return
105-
106-
gradients = np.asarray(self.gradients)
107-
if gradients.ndim != 2:
108-
raise ValueError(GRADIENT_NDIM_ERROR_MSG)
109-
if gradients.shape[1] == 4:
110-
pass
111-
elif gradients.shape[0] == 4:
112-
gradients = gradients.T
113-
else:
114-
raise ValueError(GRADIENT_EXPECTED_COLUMNS_ERROR_MSG)
115-
116-
n_volumes = None
117-
if self.dataobj is not None:
118-
try:
119-
n_volumes = self.dataobj.shape[-1]
120-
except Exception: # pragma: no cover - extremely defensive
121-
n_volumes = None
170+
raise ValueError(GRADIENT_DATA_MISSING_ERROR)
122171

123-
if n_volumes is not None and gradients.shape[0] != n_volumes:
172+
if self.dataobj.shape[-1] != self.gradients.shape[0]:
124173
raise ValueError(
125174
GRADIENT_VOLUME_DIMENSIONALITY_MISMATCH_MISSING_ERROR.format(
126-
n_volumes=n_volumes, n_gradients=gradients.shape[0]
175+
n_volumes=self.dataobj.shape[-1],
176+
n_gradients=self.gradients.shape[0],
127177
)
128178
)
129179

130-
self.gradients = gradients
180+
b0_mask = self.gradients[:, -1] <= DEFAULT_LOWB_THRESHOLD
181+
b0_num = np.sum(b0_mask)
182+
183+
if b0_num > 0 and self.bzero is None:
184+
bzeros = self.dataobj[..., b0_mask]
185+
self.bzero = bzeros if bzeros.ndim == 3 else np.median(bzeros, axis=-1)
186+
187+
if b0_num > 0:
188+
# Remove b0 volumes from dataobj and gradients
189+
self.gradients = self.gradients[~b0_mask, :]
190+
self.dataobj = self.dataobj[..., ~b0_mask]
191+
192+
if self.gradients.shape[0] < DTI_MIN_ORIENTATIONS:
193+
raise ValueError(
194+
f"DWI datasets must have at least {DTI_MIN_ORIENTATIONS} diffusion-weighted "
195+
f"orientations; found {self.dataobj.shape[-1]}."
196+
)
131197

132198
def _getextra(self, idx: int | slice | tuple | np.ndarray) -> tuple[np.ndarray]:
133199
return (self.gradients[idx, ...],)
@@ -339,12 +405,10 @@ def to_nifti(
339405
def from_nii(
340406
filename: Path | str,
341407
brainmask_file: Path | str | None = None,
342-
motion_file: Path | str | None = None,
343408
gradients_file: Path | str | None = None,
344409
bvec_file: Path | str | None = None,
345410
bval_file: Path | str | None = None,
346411
b0_file: Path | str | None = None,
347-
b0_thres: float = DEFAULT_LOWB_THRESHOLD,
348412
) -> DWI:
349413
"""
350414
Load DWI data from NIfTI and construct a DWI object.
@@ -359,8 +423,6 @@ def from_nii(
359423
brainmask_file : :obj:`os.pathlike`, optional
360424
A brainmask NIfTI file. If provided, will be loaded and
361425
stored in the returned dataset.
362-
motion_file : :obj:`os.pathlike`, optional
363-
A file containing head motion affine matrices (linear)
364426
gradients_file : :obj:`os.pathlike`, optional
365427
A text file containing the gradients table, shape (N, C) where the last column
366428
stores the b-values. If provided following the column-major convention(C, N),
@@ -373,9 +435,6 @@ def from_nii(
373435
b0_file : :obj:`os.pathlike`, optional
374436
A NIfTI file containing a b=0 volume (possibly averaged or reference).
375437
If not provided, and the data contains at least one b=0 volume, one will be computed.
376-
b0_thres : float, optional
377-
Threshold for determining which volumes are considered DWI vs. b=0
378-
if you combine them in the same file.
379438
380439
Returns
381440
-------
@@ -390,10 +449,6 @@ def from_nii(
390449
``bvec_file`` + ``bval_file``).
391450
392451
"""
393-
394-
if motion_file:
395-
raise NotImplementedError
396-
397452
filename = Path(filename)
398453

399454
# 1) Load a NIfTI
@@ -405,18 +460,8 @@ def from_nii(
405460
grad = np.loadtxt(gradients_file, dtype="float32")
406461
if bvec_file and bval_file:
407462
warn(GRADIENT_BVAL_BVEC_PRIORITY_WARN_MSG, stacklevel=2)
408-
if grad.ndim != 2:
409-
raise ValueError(GRADIENT_NDIM_ERROR_MSG)
410-
if grad.shape[1] == 4:
411-
pass
412-
elif grad.shape[0] == 4:
413-
grad = grad.T
414-
else:
415-
raise ValueError(GRADIENT_EXPECTED_COLUMNS_ERROR_MSG)
416463
elif bvec_file and bval_file:
417464
bvecs = np.loadtxt(bvec_file, dtype="float32")
418-
if bvecs.ndim == 1:
419-
bvecs = bvecs[np.newaxis, :]
420465
if bvecs.shape[1] != 3 and bvecs.shape[0] == 3:
421466
bvecs = bvecs.T
422467

@@ -427,40 +472,26 @@ def from_nii(
427472
else:
428473
raise RuntimeError(GRADIENT_DATA_MISSING_ERROR)
429474

430-
# 3) Create the DWI instance. We'll filter out volumes where b-value > b0_thres
431-
# as "DW volumes" if the user wants to store only the high-b volumes here
432-
gradmsk = grad[:, -1] > b0_thres
433-
434-
dwi_obj = DWI(
435-
dataobj=fulldata[..., gradmsk],
436-
affine=img.affine,
437-
# We'll assign the filtered gradients below.
438-
)
439-
440-
dwi_obj.gradients = grad[gradmsk, :]
441-
dwi_obj._normalize_gradients()
442-
443-
# 4) b=0 volume (bzero)
444-
# If the user provided a b0_file, load it
475+
# 3) Read b-zero volume if provided
476+
b0_data = None
445477
if b0_file:
446478
b0img = load_api(b0_file, SpatialImage)
447-
b0vol = np.asanyarray(b0img.dataobj)
448-
# We'll assume your DWI class has a bzero: np.ndarray | None attribute
449-
dwi_obj.bzero = b0vol
450-
# Otherwise, if any volumes remain outside gradmsk, compute a median B0:
451-
elif np.any(~gradmsk):
452-
# The b=0 volumes are those that did NOT pass b0_thres
453-
b0_volumes = fulldata[..., ~gradmsk]
454-
# A simple approach is to take the median across that last dimension
455-
# Note that axis=3 is valid only if your data is 4D (x, y, z, volumes).
456-
dwi_obj.bzero = np.median(b0_volumes, axis=3)
457-
458-
# 5) If a brainmask_file was provided, load it
479+
b0_data = np.asanyarray(b0img.dataobj)
480+
481+
# 4) If a brainmask_file was provided, load it
482+
brainmask_data = None
459483
if brainmask_file:
460484
mask_img = load_api(brainmask_file, SpatialImage)
461-
dwi_obj.brainmask = np.asanyarray(mask_img.dataobj, dtype=bool)
485+
brainmask_data = np.asanyarray(mask_img.dataobj, dtype=bool)
462486

463-
return dwi_obj
487+
# 5) Create and return the DWI instance.
488+
return DWI(
489+
dataobj=fulldata,
490+
affine=img.affine,
491+
gradients=grad,
492+
bzero=b0_data,
493+
brainmask=brainmask_data,
494+
)
464495

465496

466497
def find_shelling_scheme(

src/nifreeze/data/pet.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -221,7 +221,6 @@ def from_nii(
221221
filename: Path | str,
222222
frame_time: np.ndarray | list[float],
223223
brainmask_file: Path | str | None = None,
224-
motion_file: Path | str | None = None,
225224
frame_duration: np.ndarray | list[float] | None = None,
226225
) -> PET:
227226
"""
@@ -236,8 +235,6 @@ def from_nii(
236235
brainmask_file : :obj:`os.pathlike`, optional
237236
A brainmask NIfTI file. If provided, will be loaded and
238237
stored in the returned dataset.
239-
motion_file : :obj:`os.pathlike`, optional
240-
A file containing head motion affine matrices (linear).
241238
frame_duration : :obj:`numpy.ndarray` or :obj:`list` of :obj:`float`, optional
242239
The duration of each frame.
243240
If ``None``, it is derived by the difference of consecutive frame times,
@@ -254,8 +251,6 @@ def from_nii(
254251
If ``frame_time`` is not provided (BIDS requires it).
255252
256253
"""
257-
if motion_file:
258-
raise NotImplementedError
259254

260255
filename = Path(filename)
261256
# Load from NIfTI

0 commit comments

Comments
 (0)