Skip to content

Commit 275fd21

Browse files
committed
ENH: Validate DWI data objects' attributes at instantiation
Validate DWI data objects' attributes at instantiation: ensures that the attributes are present and match the expected dimensionalities.
1 parent d5fc6e1 commit 275fd21

File tree

2 files changed

+359
-53
lines changed

2 files changed

+359
-53
lines changed

src/nifreeze/data/dmri.py

Lines changed: 207 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -34,11 +34,23 @@
3434
import numpy as np
3535
import numpy.typing as npt
3636
from nibabel.spatialimages import SpatialImage
37+
from numpy.typing import ArrayLike
3738
from typing_extensions import Self
3839

39-
from nifreeze.data.base import BaseDataset, _cmp, _data_repr
40+
from nifreeze.data.base import BaseDataset, _cmp, _data_repr, _has_dim_size, _has_ndim
4041
from nifreeze.utils.ndimage import get_data, load_api
4142

43+
GRADIENT_ABSENCE_ERROR_MSG = "DWI 'gradients' may not be None"
44+
"""DWI initialization gradient absence error message."""
45+
46+
GRADIENT_OBJECT_ERROR_MSG = "DWI 'gradients' must be a numpy array."
47+
"""DWI initialization gradient object error message."""
48+
49+
GRADIENT_COUNT_MISMATCH_ERROR_MSG = (
50+
"DWI gradients count ({n_gradients}) does not match dataset volumes ({data_vols})."
51+
)
52+
"""DWI initialization gradient count mismatch error message."""
53+
4254
DEFAULT_CLIP_PERCENTILE = 75
4355
"""Upper percentile threshold for intensity clipping."""
4456

@@ -64,6 +76,150 @@
6476
"""Minimum number of nonzero b-values in a DWI dataset."""
6577

6678

79+
def _check_gradient_shape(value: np.ndarray) -> None:
80+
"""Strictly validate a gradients ndarray.
81+
82+
Validates that ``value`` is a correctly-shaped NumPy array representing
83+
gradients. It performs a sequence of checks and raises :exc:`TypeError` or
84+
:exc:`ValueError` with intentionally explicit messages suitable for use by
85+
higher-level validators.
86+
87+
The following conditions raise an exception:
88+
- ``value`` is not a 2D :obj:`~numpy.ndarray`.
89+
- ``value`` does not have 4 columns.
90+
91+
Parameters
92+
----------
93+
value : :obj:`~numpy.ndarray`
94+
The candidate gradients array.
95+
96+
Raises
97+
------
98+
:exc:`ValueError`
99+
If ``value`` fails any of the checks described above.
100+
101+
Examples
102+
--------
103+
>>> _check_gradient_shape(np.zeros((10, 3))) # valid: does not raise
104+
>>> _check_gradient_shape(np.asarray([[1, 2, 3], [1, 2]]) # raises ValueError
105+
>>> _check_gradient_shape(np.zeros((5,))) # raises ValueError
106+
>>> _check_gradient_shape(np.zeros((2, 6))) # raises ValueError
107+
"""
108+
109+
if value is None:
110+
raise ValueError(GRADIENT_ABSENCE_ERROR_MSG)
111+
112+
# Reject ragged/object-dtype arrays explicitly
113+
if value.dtype == object:
114+
raise TypeError(GRADIENT_OBJECT_ERROR_MSG)
115+
116+
if not _has_ndim(value, 2):
117+
raise ValueError(GRADIENT_NDIM_ERROR_MSG)
118+
119+
if not _has_dim_size(value, 4):
120+
raise ValueError(GRADIENT_EXPECTED_COLUMNS_ERROR_MSG)
121+
122+
123+
def _gradients_converter(value: ArrayLike) -> np.ndarray:
124+
"""Permissive gradient converter.
125+
126+
Behavior:
127+
- Converts the incoming ``value`` to a float NumPy array.
128+
- Ensures the result is 2-D and that one dimension equals 4.
129+
- If a 2-D array has ``shape[0] == 4`` and ``shape[1] != 4``, it will be
130+
transposed so the returned array has ``shape[1] == 4``.
131+
- For 1-D inputs of length 4, returns an array shaped ``(1, 4)``.
132+
- Raises exc:`TypeError` for conversion failures and exc:`ValueError` for
133+
shape violations.
134+
135+
Parameters
136+
----------
137+
value : :obj:`ArrayLike`
138+
Input to convert to a :obj:`~numpy.ndarray` of floats.
139+
140+
Returns
141+
-------
142+
:obj:`~numpy.ndarray`
143+
A 2-D float array with ``shape[1] == 4``.
144+
145+
Raises
146+
------
147+
exc:`TypeError`
148+
If the input cannot be converted to a float :obj:`~numpy.ndarray`.
149+
exc:`ValueError`
150+
If the converted array is not 2-D (after the 1-D -> 2-D promotion)
151+
or does not have a dimension of size 4 such that the returned array
152+
can be shaped with ``shape[1] == 4``.
153+
154+
Examples
155+
--------
156+
>>> _gradients_converter([0, 0, 0, 1]).shape
157+
(1, 4)
158+
>>> _gradients_converter(np.zeros((10, 4))).shape
159+
(10, 4)
160+
>>> _gradients_converter(np.zeros((4, 10))).shape
161+
(10, 4) # transposed so shape[1] == 4
162+
"""
163+
164+
if value is None:
165+
raise ValueError(GRADIENT_ABSENCE_ERROR_MSG)
166+
167+
# Convert to ndarray
168+
if isinstance(value, np.ndarray):
169+
arr = value.astype(float, copy=False)
170+
else:
171+
try:
172+
arr = np.asarray(value, dtype=float)
173+
except (TypeError, ValueError) as exc:
174+
# Conversion failed (e.g. nested ragged objects, non-numeric)
175+
raise TypeError(GRADIENT_OBJECT_ERROR_MSG) from exc
176+
177+
_check_gradient_shape(arr)
178+
179+
if arr.shape[1] == 4:
180+
pass
181+
else:
182+
arr = arr.T
183+
184+
# ToDo
185+
# Call gradient normalization
186+
return arr
187+
188+
189+
def _gradients_validator(inst: DWI, attr: attrs.Attribute, value: Any) -> None:
190+
"""Strict validator for use in attribute validation (e.g. attrs / validators).
191+
192+
Enforces that ``value`` is a NumPy array and has the expected 2-D shape
193+
with 4 columns (``shape[1] == 4``).
194+
195+
This function is intended for use as an attrs-style validator.
196+
197+
Raises
198+
------
199+
exc:`TypeError`
200+
If ``value`` is not a :obj:`~numpy.ndarray`.
201+
exc:`ValueError``
202+
If ``value`` is not 2-D or its shape does not have 4 columns.
203+
204+
Parameters
205+
----------
206+
inst : :obj:`:obj:`~nifreeze.data.dmri.DWI`
207+
The instance being validated (unused, present for validator signature).
208+
attr : :obj:`~attrs.Attribute`
209+
The attribute being validated (unused, present for validator signature).
210+
value : :obj:`Any`
211+
The value to validate.
212+
"""
213+
214+
if value is None:
215+
raise ValueError(GRADIENT_ABSENCE_ERROR_MSG)
216+
217+
if not isinstance(value, np.ndarray):
218+
raise TypeError(GRADIENT_OBJECT_ERROR_MSG)
219+
220+
_check_gradient_shape(value)
221+
222+
67223
@attrs.define(slots=True)
68224
class DWI(BaseDataset[np.ndarray]):
69225
"""Data representation structure for dMRI data."""
@@ -72,41 +228,44 @@ class DWI(BaseDataset[np.ndarray]):
72228
default=None, repr=_data_repr, eq=attrs.cmp_using(eq=_cmp)
73229
)
74230
"""A *b=0* reference map, preferably obtained by some smart averaging."""
75-
gradients: np.ndarray = attrs.field(default=None, repr=_data_repr, eq=attrs.cmp_using(eq=_cmp))
231+
gradients: np.ndarray = attrs.field(
232+
default=None,
233+
repr=_data_repr,
234+
eq=attrs.cmp_using(eq=_cmp),
235+
validator=_gradients_validator,
236+
converter=_gradients_converter,
237+
)
76238
"""A 2D numpy array of the gradient table (``N`` orientations x ``C`` components)."""
77239
eddy_xfms: list = attrs.field(default=None)
78240
"""List of transforms to correct for estimated eddy current distortions."""
79241

80242
def __attrs_post_init__(self) -> None:
81-
self._normalize_gradients()
82-
83-
def _normalize_gradients(self) -> None:
84-
if self.gradients is None:
85-
return
243+
"""Enforce basic consistency of required dMRI fields at instantiation
244+
time.
86245
87-
gradients = np.asarray(self.gradients)
88-
if gradients.ndim != 2:
89-
raise ValueError("Gradient table must be a 2D array")
246+
Specifically, the number of gradient directions must match the last
247+
dimension of the data (number of volumes).
248+
"""
90249

250+
# If the data object exists and has a time/volume axis, ensure sizes
251+
# match.
91252
n_volumes = None
92-
if self.dataobj is not None:
93-
try:
94-
n_volumes = self.dataobj.shape[-1]
95-
except Exception: # pragma: no cover - extremely defensive
96-
n_volumes = None
97-
98-
if n_volumes is not None and gradients.shape[0] != n_volumes:
99-
if gradients.shape[1] == n_volumes:
100-
gradients = gradients.T
101-
else:
253+
if getattr(self, "dataobj", None) is not None:
254+
shape = getattr(self.dataobj, "shape", None)
255+
if isinstance(shape, (tuple, list)) and len(shape) >= 1:
256+
try:
257+
n_volumes = int(shape[-1])
258+
except (TypeError, ValueError):
259+
n_volumes = None
260+
261+
if n_volumes is not None:
262+
n_gradients = self.gradients.shape[1]
263+
if n_gradients != n_volumes:
102264
raise ValueError(
103-
"Gradient table shape does not match the number of diffusion volumes: "
104-
f"expected {n_volumes} rows, found {gradients.shape[0]}"
265+
GRADIENT_COUNT_MISMATCH_ERROR_MSG.format(
266+
n_gradients=n_gradients, data_vols=n_volumes
267+
)
105268
)
106-
elif n_volumes is None and gradients.shape[1] > gradients.shape[0]:
107-
gradients = gradients.T
108-
109-
self.gradients = gradients
110269

111270
def _getextra(self, idx: int | slice | tuple | np.ndarray) -> tuple[np.ndarray]:
112271
return (self.gradients[idx, ...],)
@@ -315,6 +474,20 @@ def to_nifti(
315474
return nii
316475

317476

477+
def _compose_gradients(bvec_file: Path | str, bval_file: Path | str):
478+
bvecs = np.loadtxt(bvec_file, dtype="float32")
479+
if bvecs.ndim == 1:
480+
bvecs = bvecs[np.newaxis, :]
481+
if bvecs.shape[1] != 3 and bvecs.shape[0] == 3:
482+
bvecs = bvecs.T
483+
484+
bvals = np.loadtxt(bval_file, dtype="float32")
485+
if bvals.ndim > 1:
486+
bvals = np.squeeze(bvals)
487+
488+
return np.column_stack((bvecs, bvals))
489+
490+
318491
def from_nii(
319492
filename: Path | str,
320493
brainmask_file: Path | str | None = None,
@@ -389,35 +562,14 @@ def from_nii(
389562
stacklevel=2,
390563
)
391564
elif bvec_file and bval_file:
392-
bvecs = np.loadtxt(bvec_file, dtype="float32")
393-
if bvecs.ndim == 1:
394-
bvecs = bvecs[np.newaxis, :]
395-
if bvecs.shape[1] != 3 and bvecs.shape[0] == 3:
396-
bvecs = bvecs.T
397-
398-
bvals = np.loadtxt(bval_file, dtype="float32")
399-
if bvals.ndim > 1:
400-
bvals = np.squeeze(bvals)
401-
grad = np.column_stack((bvecs, bvals))
565+
grad = _compose_gradients(bvec_file, bval_file)
402566
else:
403567
raise RuntimeError(
404568
"No gradient data provided. "
405569
"Please specify either a gradients_file or (bvec_file & bval_file)."
406570
)
407571

408-
if grad.ndim == 1:
409-
grad = grad[np.newaxis, :]
410-
411-
if grad.shape[1] < 2:
412-
raise ValueError("Gradient table must have at least two columns (direction + b-value).")
413-
414-
if grad.shape[1] != 4:
415-
if grad.shape[0] == 4:
416-
grad = grad.T
417-
else:
418-
raise ValueError(
419-
"Gradient table must have four columns (3 direction components and one b-value)."
420-
)
572+
grad = _gradients_converter(grad)
421573

422574
# 3) Create the DWI instance. We'll filter out volumes where b-value > b0_thres
423575
# as "DW volumes" if the user wants to store only the high-b volumes here
@@ -426,11 +578,14 @@ def from_nii(
426578
dwi_obj = DWI(
427579
dataobj=fulldata[..., gradmsk],
428580
affine=img.affine,
429-
# We'll assign the filtered gradients below.
581+
gradients=grad[
582+
gradmsk, :
583+
], # ToDo Duplicate call to _gradients_converter but cannot do better I think
430584
)
431585

432-
dwi_obj.gradients = grad[gradmsk, :]
433-
dwi_obj._normalize_gradients()
586+
# removing gradients = np.asarray(self.gradients) from _normalize_gradients:
587+
# the annotation does not suggest anything other than arrays: if we want a list of lists, we should type hint that.
588+
# The converter duplicates the checks, and we could skip it in the signature, but I think it is wise to keep it
434589

435590
# 4) b=0 volume (bzero)
436591
# If the user provided a b0_file, load it

0 commit comments

Comments
 (0)