From 3ba23e8da653f021c07642e32847d075bcc6283c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jon=20Haitz=20Legarreta=20Gorro=C3=B1o?= Date: Fri, 21 Nov 2025 08:05:11 -0500 Subject: [PATCH] ENH: Validate DWI data objects' attributes at instantiation Validate DWI data objects' attributes at instantiation: ensures that the attributes are present and match the expected dimensionalities. --- src/nifreeze/data/dmri.py | 259 ++++++++++++++++++++++++++++++-------- test/test_data_dmri.py | 153 +++++++++++++++++++++- 2 files changed, 359 insertions(+), 53 deletions(-) diff --git a/src/nifreeze/data/dmri.py b/src/nifreeze/data/dmri.py index e7b51b499..2ef43bac9 100644 --- a/src/nifreeze/data/dmri.py +++ b/src/nifreeze/data/dmri.py @@ -34,11 +34,23 @@ import numpy as np import numpy.typing as npt from nibabel.spatialimages import SpatialImage +from numpy.typing import ArrayLike from typing_extensions import Self -from nifreeze.data.base import BaseDataset, _cmp, _data_repr +from nifreeze.data.base import BaseDataset, _cmp, _data_repr, _has_dim_size, _has_ndim from nifreeze.utils.ndimage import get_data, load_api +GRADIENT_ABSENCE_ERROR_MSG = "DWI 'gradients' may not be None" +"""DWI initialization gradient absence error message.""" + +GRADIENT_OBJECT_ERROR_MSG = "DWI 'gradients' must be a numpy array." +"""DWI initialization gradient object error message.""" + +GRADIENT_COUNT_MISMATCH_ERROR_MSG = ( + "DWI gradients count ({n_gradients}) does not match dataset volumes ({data_vols})." +) +"""DWI initialization gradient count mismatch error message.""" + DEFAULT_CLIP_PERCENTILE = 75 """Upper percentile threshold for intensity clipping.""" @@ -64,6 +76,150 @@ """Minimum number of nonzero b-values in a DWI dataset.""" +def _check_gradient_shape(value: np.ndarray) -> None: + """Strictly validate a gradients ndarray. + + Validates that ``value`` is a correctly-shaped NumPy array representing + gradients. It performs a sequence of checks and raises :exc:`TypeError` or + :exc:`ValueError` with intentionally explicit messages suitable for use by + higher-level validators. + + The following conditions raise an exception: + - ``value`` is not a 2D :obj:`~numpy.ndarray`. + - ``value`` does not have 4 columns. + + Parameters + ---------- + value : :obj:`~numpy.ndarray` + The candidate gradients array. + + Raises + ------ + :exc:`ValueError` + If ``value`` fails any of the checks described above. + + Examples + -------- + >>> _check_gradient_shape(np.zeros((10, 3))) # valid: does not raise + >>> _check_gradient_shape(np.asarray([[1, 2, 3], [1, 2]]) # raises ValueError + >>> _check_gradient_shape(np.zeros((5,))) # raises ValueError + >>> _check_gradient_shape(np.zeros((2, 6))) # raises ValueError + """ + + if value is None: + raise ValueError(GRADIENT_ABSENCE_ERROR_MSG) + + # Reject ragged/object-dtype arrays explicitly + if value.dtype == object: + raise TypeError(GRADIENT_OBJECT_ERROR_MSG) + + if not _has_ndim(value, 2): + raise ValueError(GRADIENT_NDIM_ERROR_MSG) + + if not _has_dim_size(value, 4): + raise ValueError(GRADIENT_EXPECTED_COLUMNS_ERROR_MSG) + + +def format_gradients(value: ArrayLike) -> np.ndarray: + """Permissive gradient formatter. + + Behavior: + - Converts the incoming ``value`` to a float NumPy array. + - Ensures the result is 2D and that one dimension equals 4. + - If a 2D array has ``shape[0] == 4`` and ``shape[1] != 4``, it will be + transposed so the returned array has ``shape[1] == 4``. + - For 1D inputs of length 4, returns an array shaped ``(1, 4)``. + - Raises exc:`TypeError` for conversion failures and exc:`ValueError` for + shape violations. + + Parameters + ---------- + value : :obj:`ArrayLike` + Input to convert to a :obj:`~numpy.ndarray` of floats. + + Returns + ------- + :obj:`~numpy.ndarray` + A 2D float array with ``shape[1] == 4``. + + Raises + ------ + exc:`TypeError` + If the input cannot be converted to a float :obj:`~numpy.ndarray`. + exc:`ValueError` + If the converted array is not 2D (after the 1D -> 2D promotion) + or does not have a dimension of size 4 such that the returned array + can be shaped with ``shape[1] == 4``. + + Examples + -------- + >>> format_gradients([0, 0, 0, 1]).shape + (1, 4) + >>> format_gradients(np.zeros((10, 4))).shape + (10, 4) + >>> format_gradients(np.zeros((4, 10))).shape + (10, 4) # transposed so shape[1] == 4 + """ + + if value is None: + raise ValueError(GRADIENT_ABSENCE_ERROR_MSG) + + # Convert to ndarray + if isinstance(value, np.ndarray): + arr = value.astype(float, copy=False) + else: + try: + arr = np.asarray(value, dtype=float) + except (TypeError, ValueError) as exc: + # Conversion failed (e.g. nested ragged objects, non-numeric) + raise TypeError(GRADIENT_OBJECT_ERROR_MSG) from exc + + _check_gradient_shape(arr) + + if arr.shape[1] == 4: + pass + else: + arr = arr.T + + # ToDo + # Call gradient normalization + return arr + + +def validate_gradients(inst: DWI, attr: attrs.Attribute, value: Any) -> None: + """Strict validator for use in attribute validation (e.g. attrs / validators). + + Enforces that ``value`` is a NumPy array and has the expected 2D shape + with 4 columns (``shape[1] == 4``). + + This function is intended for use as an attrs-style validator. + + Raises + ------ + exc:`TypeError` + If ``value`` is not a :obj:`~numpy.ndarray`. + exc:`ValueError`` + If ``value`` is not 2D or its shape does not have 4 columns. + + Parameters + ---------- + inst : :obj:`:obj:`~nifreeze.data.dmri.DWI` + The instance being validated (unused, present for validator signature). + attr : :obj:`~attrs.Attribute` + The attribute being validated (unused, present for validator signature). + value : :obj:`Any` + The value to validate. + """ + + if value is None: + raise ValueError(GRADIENT_ABSENCE_ERROR_MSG) + + if not isinstance(value, np.ndarray): + raise TypeError(GRADIENT_OBJECT_ERROR_MSG) + + _check_gradient_shape(value) + + @attrs.define(slots=True) class DWI(BaseDataset[np.ndarray]): """Data representation structure for dMRI data.""" @@ -72,41 +228,44 @@ class DWI(BaseDataset[np.ndarray]): default=None, repr=_data_repr, eq=attrs.cmp_using(eq=_cmp) ) """A *b=0* reference map, preferably obtained by some smart averaging.""" - gradients: np.ndarray = attrs.field(default=None, repr=_data_repr, eq=attrs.cmp_using(eq=_cmp)) + gradients: np.ndarray = attrs.field( + default=None, + repr=_data_repr, + eq=attrs.cmp_using(eq=_cmp), + converter=format_gradients, + validator=validate_gradients, + ) """A 2D numpy array of the gradient table (``N`` orientations x ``C`` components).""" eddy_xfms: list = attrs.field(default=None) """List of transforms to correct for estimated eddy current distortions.""" def __attrs_post_init__(self) -> None: - self._normalize_gradients() - - def _normalize_gradients(self) -> None: - if self.gradients is None: - return + """Enforce basic consistency of required dMRI fields at instantiation + time. - gradients = np.asarray(self.gradients) - if gradients.ndim != 2: - raise ValueError("Gradient table must be a 2D array") + Specifically, the number of gradient directions must match the last + dimension of the data (number of volumes). + """ + # If the data object exists and has a time/volume axis, ensure sizes + # match. n_volumes = None - if self.dataobj is not None: - try: - n_volumes = self.dataobj.shape[-1] - except Exception: # pragma: no cover - extremely defensive - n_volumes = None - - if n_volumes is not None and gradients.shape[0] != n_volumes: - if gradients.shape[1] == n_volumes: - gradients = gradients.T - else: + if getattr(self, "dataobj", None) is not None: + shape = getattr(self.dataobj, "shape", None) + if isinstance(shape, (tuple, list)) and len(shape) >= 1: + try: + n_volumes = int(shape[-1]) + except (TypeError, ValueError): + n_volumes = None + + if n_volumes is not None: + n_gradients = self.gradients.shape[1] + if n_gradients != n_volumes: raise ValueError( - "Gradient table shape does not match the number of diffusion volumes: " - f"expected {n_volumes} rows, found {gradients.shape[0]}" + GRADIENT_COUNT_MISMATCH_ERROR_MSG.format( + n_gradients=n_gradients, data_vols=n_volumes + ) ) - elif n_volumes is None and gradients.shape[1] > gradients.shape[0]: - gradients = gradients.T - - self.gradients = gradients def _getextra(self, idx: int | slice | tuple | np.ndarray) -> tuple[np.ndarray]: return (self.gradients[idx, ...],) @@ -315,6 +474,20 @@ def to_nifti( return nii +def _compose_gradients(bvec_file: Path | str, bval_file: Path | str): + bvecs = np.loadtxt(bvec_file, dtype="float32") + if bvecs.ndim == 1: + bvecs = bvecs[np.newaxis, :] + if bvecs.shape[1] != 3 and bvecs.shape[0] == 3: + bvecs = bvecs.T + + bvals = np.loadtxt(bval_file, dtype="float32") + if bvals.ndim > 1: + bvals = np.squeeze(bvals) + + return np.column_stack((bvecs, bvals)) + + def from_nii( filename: Path | str, brainmask_file: Path | str | None = None, @@ -389,35 +562,14 @@ def from_nii( stacklevel=2, ) elif bvec_file and bval_file: - bvecs = np.loadtxt(bvec_file, dtype="float32") - if bvecs.ndim == 1: - bvecs = bvecs[np.newaxis, :] - if bvecs.shape[1] != 3 and bvecs.shape[0] == 3: - bvecs = bvecs.T - - bvals = np.loadtxt(bval_file, dtype="float32") - if bvals.ndim > 1: - bvals = np.squeeze(bvals) - grad = np.column_stack((bvecs, bvals)) + grad = _compose_gradients(bvec_file, bval_file) else: raise RuntimeError( "No gradient data provided. " "Please specify either a gradients_file or (bvec_file & bval_file)." ) - if grad.ndim == 1: - grad = grad[np.newaxis, :] - - if grad.shape[1] < 2: - raise ValueError("Gradient table must have at least two columns (direction + b-value).") - - if grad.shape[1] != 4: - if grad.shape[0] == 4: - grad = grad.T - else: - raise ValueError( - "Gradient table must have four columns (3 direction components and one b-value)." - ) + grad = format_gradients(grad) # 3) Create the DWI instance. We'll filter out volumes where b-value > b0_thres # as "DW volumes" if the user wants to store only the high-b volumes here @@ -426,11 +578,14 @@ def from_nii( dwi_obj = DWI( dataobj=fulldata[..., gradmsk], affine=img.affine, - # We'll assign the filtered gradients below. + gradients=grad[ + gradmsk, : + ], # ToDo Duplicate call to format_gradients but cannot do better I think ) - dwi_obj.gradients = grad[gradmsk, :] - dwi_obj._normalize_gradients() + # removing gradients = np.asarray(self.gradients) from _normalize_gradients: + # the annotation does not suggest anything other than arrays: if we want a list of lists, we should type hint that. + # The converter duplicates the checks, and we could skip it in the signature, but I think it is wise to keep it # 4) b=0 volume (bzero) # If the user provided a b0_file, load it diff --git a/test/test_data_dmri.py b/test/test_data_dmri.py index 8eab28e10..fe2028be1 100644 --- a/test/test_data_dmri.py +++ b/test/test_data_dmri.py @@ -22,6 +22,7 @@ # """Unit tests exercising the dMRI data structure.""" +import re from pathlib import Path import nibabel as nb @@ -29,7 +30,15 @@ import pytest from nifreeze.data import load -from nifreeze.data.dmri import DWI, find_shelling_scheme, from_nii, transform_fsl_bvec +from nifreeze.data.dmri import ( + DWI, + GRADIENT_ABSENCE_ERROR_MSG, + GRADIENT_COUNT_MISMATCH_ERROR_MSG, + GRADIENT_SHAPE_ERROR_MSG, + find_shelling_scheme, + from_nii, + transform_fsl_bvec, +) from nifreeze.utils.ndimage import load_api @@ -82,6 +91,36 @@ def test_motion_file_not_implemented(): from_nii("dmri.nii.gz", motion_file="motion.x5") +@pytest.mark.random_uniform_spatial_data((2, 2, 2, 4), 0.0, 1.0) +def test_missing_gradients_error(setup_random_uniform_spatial_data): + data, affine = setup_random_uniform_spatial_data + with pytest.raises(ValueError, match=GRADIENT_ABSENCE_ERROR_MSG): + DWI(dataobj=data, affine=affine) + + +@pytest.mark.random_uniform_spatial_data((2, 2, 2, 4), 0.0, 1.0) +def test_gradients_shape_error(setup_random_uniform_spatial_data): + data, affine = setup_random_uniform_spatial_data + gradients = np.zeros((3, data.shape[-1])) + with pytest.raises(ValueError, match=re.escape(GRADIENT_SHAPE_ERROR_MSG)): + DWI(dataobj=data, affine=affine, gradients=gradients) + + +@pytest.mark.random_uniform_spatial_data((2, 2, 2, 4), 0.0, 1.0) +def test_gradients_volume_mismatch_error(setup_random_uniform_spatial_data): + data, affine = setup_random_uniform_spatial_data + data_vols = data.shape[-1] + n_gradients = data_vols + 1 + gradients = np.zeros((4, n_gradients)) + with pytest.raises( + ValueError, + match=re.escape( + GRADIENT_COUNT_MISMATCH_ERROR_MSG.format(n_gradients=n_gradients, data_vols=data_vols) + ), + ): + DWI(dataobj=data, affine=affine, gradients=gradients) + + @pytest.mark.parametrize("insert_b0", (False, True)) @pytest.mark.parametrize("rotate_bvecs", (False, True)) def test_load(datadir, tmp_path, insert_b0, rotate_bvecs): # noqa: C901 @@ -959,3 +998,115 @@ def test_transform_fsl_bvec(b_ijk, zooms, flips, axis_order, origin, angles): f"Expected {rotated_b_ijk}, got {test_b_ijk} for b_ijk={b_ijk}, " f"zooms={zooms}, origin={origin}, angles={angles}" ) + + +##################################################################### +def _get_checker(): + # Support either name used in different iterations: _check_gradient_shape or _check_ndarray_shape + return getattr(gradients, "_check_gradient_shape", None) or getattr( + gradients, "_check_ndarray_shape", None + ) + + +@pytest.mark.parametrize( + "value, expected_exc, expected_msg", + [ + (None, ValueError, GRADIENT_ABSENCE_ERROR_MSG), + (np.array([[1], [2]], dtype=object), TypeError, GRADIENT_OBJECT_ERROR_MSG), + (np.zeros((3,)), ValueError, GRADIENT_NDIM_ERROR_MSG), + (np.zeros((2, 3)), ValueError, GRADIENT_EXPECTED_COLUMNS_ERROR_MSG), + ], +) +def test_check_gradient_shape_raises_for_invalid_inputs(value, expected_exc, expected_msg): + checker = _get_checker() + assert checker is not None, "no checker function found in gradients module" + with pytest.raises(expected_exc, match=expected_msg): + checker(value) + + +@pytest.mark.parametrize( + "bad_input, expected_exc, expected_msg", + [ + (3.14, ValueError, GRADIENT_NDIM_ERROR_MSG), + ([[1, 2], [3, 4, 5]], (TypeError, ValueError), GRADIENT_OBJECT_ERROR_MSG), # Ragged + ], +) +def test_format_gradients_rejects_scalar_and_ragged( + monkeypatch, bad_input, expected_exc, expected_msg +): + # normalize no-op + monkeypatch.setattr(gradients, "_normalize_gradients", lambda x: x) + + with pytest.raises(expected_exc, match=expected_msg): + format_gradients(bad_input) + + +@pytest.mark.parametrize( + "arr", + [ + np.zeros((2, 4)), + np.zeros((4, 2)), + ], +) +def test_check_gradient_shape_accepts_valid_oriented_arrays(arr): + checker = _get_checker() + assert checker is not None, "no checker function found in gradients module" + # Should not raise + checker(arr) + + +@pytest.mark.parametrize( + "input_value, expected_shape, expected_values", + [ + ( + [[1, 2, 3, 4], [5, 6, 7, 8]], # list of lists -> (2,4) + (2, 4), + np.array([[1, 2, 3, 4], [5, 6, 7, 8]], dtype=float), + ), + ( + np.arange(8).reshape(4, 2).astype(float), # (4,2) should be transposed to (2,4) + (2, 4), + np.arange(8).reshape(4, 2).astype(float).T, + ), + ], +) +def test_format_gradients_converts_list_and_transposes_when_needed( + input_value, expected_shape, expected_values, monkeypatch +): + # Make normalize a no-op so tests can inspect returned array + monkeypatch.setattr(gradients, "_normalize_gradients", lambda x: x) + + out = format_gradients(input_value) + assert isinstance(out, np.ndarray) + assert out.shape == expected_shape + assert np.allclose(out, expected_values) + + +@pytest.mark.parametrize( + "value, expected_exc, expected_msg", + [ + (None, ValueError, GRADIENT_ABSENCE_ERROR_MSG), + ([1, 2, 3, 4], GRADIENT_OBJECT_ERROR_MSG), # Non-ndarray + ( + np.array([[1], [2]], dtype=object), + TypeError, + GRADIENT_OBJECT_ERROR_MSG, + ), # ndarray with object dtype + (np.zeros((3,)), ValueError, GRADIENT_NDIM_ERROR_MSG), # Wrong ndim + ( + np.zeros((2, 3)), + ValueError, + GRADIENT_EXPECTED_COLUMNS_ERROR_MSG, + ), # Missing dimension size 4 + ], +) +def test_validate_gradients_errors(value, expected_exc, expected_msg): + validator = validate_gradients + with pytest.raises(expected_exc, match=expected_msg): + validator(None, "gradients", value) + + +def test_validate_gradients_accepts_valid_oriented_arrays(): + # Valid arrays (either orientation) -> no exception + validate_gradients(None, "gradients", np.zeros((2, 4))) + validate_gradients(None, "gradients", np.zeros((4, 2)))