Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
259 changes: 207 additions & 52 deletions src/nifreeze/data/dmri.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,11 +34,23 @@
import numpy as np
import numpy.typing as npt
from nibabel.spatialimages import SpatialImage
from numpy.typing import ArrayLike
from typing_extensions import Self

from nifreeze.data.base import BaseDataset, _cmp, _data_repr
from nifreeze.data.base import BaseDataset, _cmp, _data_repr, _has_dim_size, _has_ndim
from nifreeze.utils.ndimage import get_data, load_api

GRADIENT_ABSENCE_ERROR_MSG = "DWI 'gradients' may not be None"
"""DWI initialization gradient absence error message."""

GRADIENT_OBJECT_ERROR_MSG = "DWI 'gradients' must be a numpy array."
"""DWI initialization gradient object error message."""

GRADIENT_COUNT_MISMATCH_ERROR_MSG = (
"DWI gradients count ({n_gradients}) does not match dataset volumes ({data_vols})."
)
"""DWI initialization gradient count mismatch error message."""

DEFAULT_CLIP_PERCENTILE = 75
"""Upper percentile threshold for intensity clipping."""

Expand All @@ -64,6 +76,150 @@
"""Minimum number of nonzero b-values in a DWI dataset."""


def _check_gradient_shape(value: np.ndarray) -> None:
"""Strictly validate a gradients ndarray.

Validates that ``value`` is a correctly-shaped NumPy array representing
gradients. It performs a sequence of checks and raises :exc:`TypeError` or
:exc:`ValueError` with intentionally explicit messages suitable for use by
higher-level validators.

The following conditions raise an exception:
- ``value`` is not a 2D :obj:`~numpy.ndarray`.
- ``value`` does not have 4 columns.

Parameters
----------
value : :obj:`~numpy.ndarray`
The candidate gradients array.

Raises
------
:exc:`ValueError`
If ``value`` fails any of the checks described above.

Examples
--------
>>> _check_gradient_shape(np.zeros((10, 3))) # valid: does not raise
>>> _check_gradient_shape(np.asarray([[1, 2, 3], [1, 2]]) # raises ValueError
>>> _check_gradient_shape(np.zeros((5,))) # raises ValueError
>>> _check_gradient_shape(np.zeros((2, 6))) # raises ValueError
"""

if value is None:
raise ValueError(GRADIENT_ABSENCE_ERROR_MSG)

# Reject ragged/object-dtype arrays explicitly
if value.dtype == object:
raise TypeError(GRADIENT_OBJECT_ERROR_MSG)

if not _has_ndim(value, 2):
raise ValueError(GRADIENT_NDIM_ERROR_MSG)

if not _has_dim_size(value, 4):
raise ValueError(GRADIENT_EXPECTED_COLUMNS_ERROR_MSG)


def format_gradients(value: ArrayLike) -> np.ndarray:
"""Permissive gradient formatter.

Behavior:
- Converts the incoming ``value`` to a float NumPy array.
- Ensures the result is 2D and that one dimension equals 4.
- If a 2D array has ``shape[0] == 4`` and ``shape[1] != 4``, it will be
transposed so the returned array has ``shape[1] == 4``.
- For 1D inputs of length 4, returns an array shaped ``(1, 4)``.
- Raises exc:`TypeError` for conversion failures and exc:`ValueError` for
shape violations.

Parameters
----------
value : :obj:`ArrayLike`
Input to convert to a :obj:`~numpy.ndarray` of floats.

Returns
-------
:obj:`~numpy.ndarray`
A 2D float array with ``shape[1] == 4``.

Raises
------
exc:`TypeError`
If the input cannot be converted to a float :obj:`~numpy.ndarray`.
exc:`ValueError`
If the converted array is not 2D (after the 1D -> 2D promotion)
or does not have a dimension of size 4 such that the returned array
can be shaped with ``shape[1] == 4``.

Examples
--------
>>> format_gradients([0, 0, 0, 1]).shape
(1, 4)
>>> format_gradients(np.zeros((10, 4))).shape
(10, 4)
>>> format_gradients(np.zeros((4, 10))).shape
(10, 4) # transposed so shape[1] == 4
"""

if value is None:
raise ValueError(GRADIENT_ABSENCE_ERROR_MSG)

# Convert to ndarray
if isinstance(value, np.ndarray):
arr = value.astype(float, copy=False)
else:
try:
arr = np.asarray(value, dtype=float)
except (TypeError, ValueError) as exc:
# Conversion failed (e.g. nested ragged objects, non-numeric)
raise TypeError(GRADIENT_OBJECT_ERROR_MSG) from exc

_check_gradient_shape(arr)

if arr.shape[1] == 4:
pass
else:
arr = arr.T

# ToDo
# Call gradient normalization
return arr


def validate_gradients(inst: DWI, attr: attrs.Attribute, value: Any) -> None:
"""Strict validator for use in attribute validation (e.g. attrs / validators).

Enforces that ``value`` is a NumPy array and has the expected 2D shape
with 4 columns (``shape[1] == 4``).

This function is intended for use as an attrs-style validator.

Raises
------
exc:`TypeError`
If ``value`` is not a :obj:`~numpy.ndarray`.
exc:`ValueError``
If ``value`` is not 2D or its shape does not have 4 columns.

Parameters
----------
inst : :obj:`:obj:`~nifreeze.data.dmri.DWI`
The instance being validated (unused, present for validator signature).
attr : :obj:`~attrs.Attribute`
The attribute being validated (unused, present for validator signature).
value : :obj:`Any`
The value to validate.
"""

if value is None:
raise ValueError(GRADIENT_ABSENCE_ERROR_MSG)

if not isinstance(value, np.ndarray):
raise TypeError(GRADIENT_OBJECT_ERROR_MSG)

_check_gradient_shape(value)


@attrs.define(slots=True)
class DWI(BaseDataset[np.ndarray]):
"""Data representation structure for dMRI data."""
Expand All @@ -72,41 +228,44 @@ class DWI(BaseDataset[np.ndarray]):
default=None, repr=_data_repr, eq=attrs.cmp_using(eq=_cmp)
)
"""A *b=0* reference map, preferably obtained by some smart averaging."""
gradients: np.ndarray = attrs.field(default=None, repr=_data_repr, eq=attrs.cmp_using(eq=_cmp))
gradients: np.ndarray = attrs.field(
default=None,
repr=_data_repr,
eq=attrs.cmp_using(eq=_cmp),
converter=format_gradients,
validator=validate_gradients,
)
"""A 2D numpy array of the gradient table (``N`` orientations x ``C`` components)."""
eddy_xfms: list = attrs.field(default=None)
"""List of transforms to correct for estimated eddy current distortions."""

def __attrs_post_init__(self) -> None:
self._normalize_gradients()

def _normalize_gradients(self) -> None:
if self.gradients is None:
return
"""Enforce basic consistency of required dMRI fields at instantiation
time.

gradients = np.asarray(self.gradients)
if gradients.ndim != 2:
raise ValueError("Gradient table must be a 2D array")
Specifically, the number of gradient directions must match the last
dimension of the data (number of volumes).
"""

# If the data object exists and has a time/volume axis, ensure sizes
# match.
n_volumes = None
if self.dataobj is not None:
try:
n_volumes = self.dataobj.shape[-1]
except Exception: # pragma: no cover - extremely defensive
n_volumes = None

if n_volumes is not None and gradients.shape[0] != n_volumes:
if gradients.shape[1] == n_volumes:
gradients = gradients.T
else:
if getattr(self, "dataobj", None) is not None:
shape = getattr(self.dataobj, "shape", None)
if isinstance(shape, (tuple, list)) and len(shape) >= 1:
try:
n_volumes = int(shape[-1])
except (TypeError, ValueError):
n_volumes = None

if n_volumes is not None:
n_gradients = self.gradients.shape[1]
if n_gradients != n_volumes:
raise ValueError(
"Gradient table shape does not match the number of diffusion volumes: "
f"expected {n_volumes} rows, found {gradients.shape[0]}"
GRADIENT_COUNT_MISMATCH_ERROR_MSG.format(
n_gradients=n_gradients, data_vols=n_volumes
)
)
elif n_volumes is None and gradients.shape[1] > gradients.shape[0]:
gradients = gradients.T

self.gradients = gradients

def _getextra(self, idx: int | slice | tuple | np.ndarray) -> tuple[np.ndarray]:
return (self.gradients[idx, ...],)
Expand Down Expand Up @@ -315,6 +474,20 @@ def to_nifti(
return nii


def _compose_gradients(bvec_file: Path | str, bval_file: Path | str):
bvecs = np.loadtxt(bvec_file, dtype="float32")
if bvecs.ndim == 1:
bvecs = bvecs[np.newaxis, :]
if bvecs.shape[1] != 3 and bvecs.shape[0] == 3:
bvecs = bvecs.T

bvals = np.loadtxt(bval_file, dtype="float32")
if bvals.ndim > 1:
bvals = np.squeeze(bvals)

return np.column_stack((bvecs, bvals))


def from_nii(
filename: Path | str,
brainmask_file: Path | str | None = None,
Expand Down Expand Up @@ -389,35 +562,14 @@ def from_nii(
stacklevel=2,
)
elif bvec_file and bval_file:
bvecs = np.loadtxt(bvec_file, dtype="float32")
if bvecs.ndim == 1:
bvecs = bvecs[np.newaxis, :]
if bvecs.shape[1] != 3 and bvecs.shape[0] == 3:
bvecs = bvecs.T

bvals = np.loadtxt(bval_file, dtype="float32")
if bvals.ndim > 1:
bvals = np.squeeze(bvals)
grad = np.column_stack((bvecs, bvals))
grad = _compose_gradients(bvec_file, bval_file)
else:
raise RuntimeError(
"No gradient data provided. "
"Please specify either a gradients_file or (bvec_file & bval_file)."
)

if grad.ndim == 1:
grad = grad[np.newaxis, :]

if grad.shape[1] < 2:
raise ValueError("Gradient table must have at least two columns (direction + b-value).")

if grad.shape[1] != 4:
if grad.shape[0] == 4:
grad = grad.T
else:
raise ValueError(
"Gradient table must have four columns (3 direction components and one b-value)."
)
grad = format_gradients(grad)

# 3) Create the DWI instance. We'll filter out volumes where b-value > b0_thres
# as "DW volumes" if the user wants to store only the high-b volumes here
Expand All @@ -426,11 +578,14 @@ def from_nii(
dwi_obj = DWI(
dataobj=fulldata[..., gradmsk],
affine=img.affine,
# We'll assign the filtered gradients below.
gradients=grad[
gradmsk, :
], # ToDo Duplicate call to format_gradients but cannot do better I think
)

dwi_obj.gradients = grad[gradmsk, :]
dwi_obj._normalize_gradients()
# removing gradients = np.asarray(self.gradients) from _normalize_gradients:
# the annotation does not suggest anything other than arrays: if we want a list of lists, we should type hint that.
# The converter duplicates the checks, and we could skip it in the signature, but I think it is wise to keep it

# 4) b=0 volume (bzero)
# If the user provided a b0_file, load it
Expand Down
Loading
Loading