Skip to content

Commit 6a81e9a

Browse files
committed
enh: add tests to docstring and integrate to overhaul
Closes: #332.
1 parent 2026fe4 commit 6a81e9a

File tree

6 files changed

+383
-373
lines changed

6 files changed

+383
-373
lines changed

src/nifreeze/data/dmri.py

Lines changed: 38 additions & 335 deletions
Original file line numberDiff line numberDiff line change
@@ -37,164 +37,21 @@
3737
from typing_extensions import Self
3838

3939
from nifreeze.data.base import BaseDataset, _cmp, _data_repr
40-
from nifreeze.utils.ndimage import get_data, load_api
41-
42-
DEFAULT_CLIP_PERCENTILE = 75
43-
"""Upper percentile threshold for intensity clipping."""
44-
45-
DEFAULT_MIN_S0 = 1e-5
46-
"""Minimum value when considering the :math:`S_{0}` DWI signal."""
47-
48-
DEFAULT_MAX_S0 = 1.0
49-
"""Maximum value when considering the :math:`S_{0}` DWI signal."""
50-
51-
DEFAULT_LOWB_THRESHOLD = 50
52-
"""The lower bound for the b-value so that the orientation is considered a DW volume."""
53-
54-
DEFAULT_GRADIENT_EPS = 1e-8
55-
"""Epsilon value for b-vector normalization."""
56-
57-
DEFAULT_HIGHB_THRESHOLD = 8000
58-
"""A b-value cap for DWI data."""
59-
60-
DEFAULT_NUM_BINS = 15
61-
"""Number of bins to classify b-values."""
62-
63-
DEFAULT_MULTISHELL_BIN_COUNT_THR = 7
64-
"""Default bin count to consider a multishell scheme."""
65-
66-
DTI_MIN_ORIENTATIONS = 6
67-
"""Minimum number of nonzero b-values in a DWI dataset."""
68-
69-
GRADIENT_ABSENCE_ERROR_MSG = "Gradient table may not be None."
70-
"""Gradient absence error message."""
71-
72-
GRADIENT_OBJECT_ERROR_MSG = "Gradient table must be a numeric homogeneous array-like object"
73-
"""Gradient object error message."""
74-
75-
GRADIENT_VOLUME_DIMENSIONALITY_MISMATCH_ERROR = """\
76-
Gradient table shape does not match the number of diffusion volumes: \
77-
expected {n_volumes} rows, found {n_gradients}."""
78-
"""dMRI volume count vs. gradient count mismatch error message."""
79-
80-
GRADIENT_BVAL_BVEC_PRIORITY_WARN_MSG = """\
81-
Both a gradients table file and b-vec/val files are defined; \
82-
ignoring b-vec/val files in favor of the gradients_file."""
83-
""""dMRI gradient file priority warning message."""
84-
85-
GRADIENT_NDIM_ERROR_MSG = "Gradient table must be a 2D array"
86-
"""dMRI gradient dimensionality error message."""
87-
88-
GRADIENT_DATA_MISSING_ERROR = "No gradient data provided."
89-
"""dMRI missing gradient data error message."""
90-
91-
GRADIENT_EXPECTED_COLUMNS_ERROR_MSG = (
92-
"Gradient table must have four columns (3 direction components and one b-value)."
40+
from nifreeze.data.dmriutils import (
41+
DEFAULT_HIGHB_THRESHOLD,
42+
DEFAULT_LOWB_THRESHOLD,
43+
DEFAULT_MULTISHELL_BIN_COUNT_THR,
44+
DEFAULT_NUM_BINS,
45+
DTI_MIN_ORIENTATIONS,
46+
GRADIENT_BVAL_BVEC_PRIORITY_WARN_MSG,
47+
GRADIENT_DATA_MISSING_ERROR,
48+
GRADIENT_EXPECTED_COLUMNS_ERROR_MSG,
49+
GRADIENT_VOLUME_DIMENSIONALITY_MISMATCH_ERROR,
50+
find_shelling_scheme,
51+
format_gradients,
52+
transform_fsl_bvec,
9353
)
94-
"""dMRI gradient expected columns error message."""
95-
96-
97-
def format_gradients(
98-
value: npt.ArrayLike | None,
99-
) -> np.ndarray | None:
100-
"""
101-
Validate and orient gradient tables to row-major convention.
102-
103-
Parameters
104-
----------
105-
value : :obj:`ArrayLike`
106-
The value to format.
107-
108-
Returns
109-
-------
110-
:obj:`~numpy.ndarray`
111-
Row-major convention gradient table.
112-
113-
Raises
114-
------
115-
exc:`ValueError`
116-
If ``value`` is not a 2D :obj:`~numpy.ndarray` (``value.ndim != 2``).
117-
118-
Examples
119-
--------
120-
Passing an already well-formed table returns the data unchanged::
121-
122-
>>> format_gradients(
123-
... [
124-
... [1, 0, 0, 0],
125-
... [0, 1, 0, 1000],
126-
... [0, 0, 1, 2000],
127-
... [0, 0, 0, 0],
128-
... [0, 0, 0, 1000],
129-
... ]
130-
... )
131-
array([[ 1, 0, 0, 0],
132-
[ 0, 1, 0, 1000],
133-
[ 0, 0, 1, 2000],
134-
[ 0, 0, 0, 0],
135-
[ 0, 0, 0, 1000]])
136-
137-
Column-major inputs are automatically transposed when an expected
138-
number of diffusion volumes is provided::
139-
140-
>>> format_gradients(
141-
... [[1, 0], [0, 1], [0, 0], [1000, 2000]],
142-
... )
143-
array([[ 1, 0, 0, 1000],
144-
[ 0, 1, 0, 2000]])
145-
146-
Gradient tables must always have two dimensions::
147-
148-
>>> format_gradients([0, 1, 0, 1000])
149-
Traceback (most recent call last):
150-
...
151-
ValueError: Gradient table must be a 2D array
152-
153-
Gradient tables must have a regular shape::
154-
155-
>>> format_gradients([[1, 2], [3, 4, 5]])
156-
Traceback (most recent call last):
157-
...
158-
TypeError: Gradient table must be a numeric homogeneous array-like object
159-
160-
Gradient tables must always have two dimensions::
161-
162-
>>> format_gradients([0, 1, 0, 1000])
163-
Traceback (most recent call last):
164-
...
165-
ValueError: Gradient table must be a 2D array
166-
167-
"""
168-
169-
if value is None:
170-
raise ValueError(GRADIENT_ABSENCE_ERROR_MSG)
171-
172-
try:
173-
formatted = np.asarray(value, dtype=float)
174-
except (TypeError, ValueError) as exc:
175-
# Conversion failed (e.g. nested ragged objects, non-numeric)
176-
raise TypeError(GRADIENT_OBJECT_ERROR_MSG) from exc
177-
178-
if formatted.ndim != 2:
179-
raise ValueError(GRADIENT_NDIM_ERROR_MSG)
180-
181-
# If the numeric values are all integers, preserve integer dtype
182-
if np.all(np.isfinite(formatted)) and np.allclose(formatted, np.round(formatted)):
183-
formatted = formatted.astype(int)
184-
185-
# Transpose if column-major
186-
formatted = formatted.T if formatted.shape[0] == 4 and formatted.shape[1] != 4 else formatted
187-
188-
# Normalize b-vectors in-place
189-
bvecs = formatted[:, :3]
190-
norms = np.linalg.norm(bvecs, axis=1)
191-
mask = norms > DEFAULT_GRADIENT_EPS
192-
if np.any(mask):
193-
formatted[mask, :3] = bvecs[mask] / norms[mask, None] # Norm b-vectors
194-
formatted[mask, 3] *= norms[mask] # Scale b-values by norm
195-
formatted[~mask, :] = 0.0 # Zero-out small b-vectors
196-
197-
return formatted
54+
from nifreeze.utils.ndimage import get_data, load_api
19855

19956

20057
def validate_gradients(
@@ -216,10 +73,33 @@ def validate_gradients(
21673
The attribute being validated; attr.name is used in the error message.
21774
value : :obj:`~npt.NDArray`
21875
The value to validate.
76+
77+
Raises
78+
------
79+
:exc:`ValueError`
80+
If the gradient table is invalid.
81+
82+
83+
Examples
84+
--------
85+
Non-finite inputs are rejected::
86+
87+
>>> validate_gradients(None, None, [[np.inf, 0.0, 0.0, 1000]])
88+
Traceback (most recent call last):
89+
...
90+
ValueError: Gradient table contains NaN or infinite values.
91+
>>> validate_gradients(None, None, [[np.nan, 0.0, 0.0, 1000]])
92+
Traceback (most recent call last):
93+
...
94+
ValueError: Gradient table contains NaN or infinite values.
95+
21996
"""
220-
if value.shape[1] != 4:
97+
if np.shape(value)[1] != 4:
22198
raise ValueError(GRADIENT_EXPECTED_COLUMNS_ERROR_MSG)
22299

100+
if not np.all(np.isfinite(value)):
101+
raise ValueError("Gradient table contains NaN or infinite values.")
102+
223103

224104
@attrs.define(slots=True)
225105
class DWI(BaseDataset[np.ndarray]):
@@ -567,180 +447,3 @@ def from_nii(
567447
bzero=b0_data,
568448
brainmask=brainmask_data,
569449
)
570-
571-
572-
def find_shelling_scheme(
573-
bvals: np.ndarray,
574-
num_bins: int = DEFAULT_NUM_BINS,
575-
multishell_nonempty_bin_count_thr: int = DEFAULT_MULTISHELL_BIN_COUNT_THR,
576-
bval_cap: float = DEFAULT_HIGHB_THRESHOLD,
577-
) -> tuple[str, list[npt.NDArray[np.floating]], list[np.floating]]:
578-
"""
579-
Find the shelling scheme on the given b-values.
580-
581-
Computes the histogram of the b-values according to ``num_bins``
582-
and depending on the nonempty bin count, classify the shelling scheme
583-
as single-shell if they are 2 (low-b and a shell); multi-shell if they are
584-
below the ``multishell_nonempty_bin_count_thr`` value; and DSI otherwise.
585-
586-
Parameters
587-
----------
588-
bvals : :obj:`list` or :obj:`~numpy.ndarray`
589-
List or array of b-values.
590-
num_bins : :obj:`int`, optional
591-
Number of bins.
592-
multishell_nonempty_bin_count_thr : :obj:`int`, optional
593-
Bin count to consider a multi-shell scheme.
594-
bval_cap : :obj:`float`, optional
595-
Maximum b-value to be considered in a multi-shell scheme.
596-
597-
Returns
598-
-------
599-
scheme : :obj:`str`
600-
Shelling scheme.
601-
bval_groups : :obj:`list`
602-
List of grouped b-values.
603-
bval_estimated : :obj:`list`
604-
List of 'estimated' b-values as the median value of each b-value group.
605-
606-
"""
607-
608-
# Bin the b-values: use -1 as the lower bound to be able to appropriately
609-
# include b0 values
610-
hist, bin_edges = np.histogram(bvals, bins=num_bins, range=(-1, min(max(bvals), bval_cap)))
611-
612-
# Collect values in each bin
613-
bval_groups = []
614-
bval_estimated = []
615-
for lower, upper in zip(bin_edges[:-1], bin_edges[1:], strict=False):
616-
# Add only if a nonempty b-values mask
617-
if (mask := (bvals > lower) & (bvals <= upper)).sum():
618-
bval_groups.append(bvals[mask])
619-
bval_estimated.append(np.median(bvals[mask]))
620-
621-
nonempty_bins = len(bval_groups)
622-
623-
if nonempty_bins < 2:
624-
raise ValueError("DWI must have at least one high-b shell")
625-
626-
if nonempty_bins == 2:
627-
scheme = "single-shell"
628-
elif nonempty_bins < multishell_nonempty_bin_count_thr:
629-
scheme = "multi-shell"
630-
else:
631-
scheme = "DSI"
632-
633-
return scheme, bval_groups, bval_estimated
634-
635-
636-
def transform_fsl_bvec(
637-
b_ijk: np.ndarray, xfm: np.ndarray, imaffine: np.ndarray, invert: bool = False
638-
) -> np.ndarray:
639-
"""
640-
Transform a b-vector from the original space to the new space defined by the affine.
641-
642-
Parameters
643-
----------
644-
b_ijk : :obj:`~numpy.ndarray`
645-
The b-vector in FSL/DIPY conventions (i.e., voxel coordinates).
646-
xfm : :obj:`~numpy.ndarray`
647-
The affine transformation to apply.
648-
Please note that this is the inverse of the head-motion-correction affine,
649-
which maps coordinates from the realigned space to the moved (scan) space.
650-
In this case, we want to move the b-vector from the moved (scan) space into
651-
the realigned space.
652-
imaffine : :obj:`~numpy.ndarray`
653-
The image's affine, to convert.
654-
invert : :obj:`bool`, optional
655-
If ``True``, the transformation will be inverted.
656-
657-
Returns
658-
-------
659-
:obj:`~numpy.ndarray`
660-
The transformed b-vector in voxel coordinates (FSL/DIPY).
661-
662-
"""
663-
xfm = np.linalg.inv(xfm) if invert else xfm.copy()
664-
665-
# Go from world coordinates (xfm) to voxel coordinates
666-
ijk2ijk_xfm = np.linalg.inv(imaffine) @ xfm @ imaffine
667-
668-
return ijk2ijk_xfm[:3, :3] @ b_ijk[:3]
669-
670-
671-
def normalize_gradients(value: np.ndarray, eps: float = 1e-8, copy: bool = True) -> np.ndarray:
672-
"""Normalize b-vectors in arrays of common shapes.
673-
674-
Parameters
675-
----------
676-
value : :obj:`~numpy.ndarray`
677-
Input array with shape one of:
678-
- (N, 3) : rows are b-vector components (e.g., [gx gy gz])
679-
- (N, 4) : first 3 columns are b-vector components (e.g., [gx gy gz b])
680-
- (3, N) : columns are b-vector components (e.g., [gx gy gz].T)
681-
- (4, N) : first 3 rows are b-vector components (e.g., [gx gy gz b].T)
682-
- (3,) or (1,3) or (3,1) : single b-vector
683-
Columns are checked first to disambiguate Nx3/Nx4 cases.
684-
eps : float, optional
685-
Threshold below which a vector is considered zero and left unchanged.
686-
copy : bool, optional
687-
If ``True``, returns a new array; modify in-place otherwise.
688-
689-
Returns
690-
-------
691-
out : :obj:`~numpy.ndarray`
692-
Array with the same shape as ``value`` with each 3-component b-vector
693-
normalized.
694-
"""
695-
arr = np.asarray(value, dtype=float)
696-
697-
# 1D single vector
698-
if arr.ndim == 1:
699-
if arr.size != 3:
700-
raise ValueError(GRADIENT_NORMALIZATION_LENGTH_ERROR_MSG)
701-
norm = np.linalg.norm(arr)
702-
if norm > eps:
703-
if copy:
704-
return arr / norm
705-
else:
706-
# Perform in-place normalization on the array view
707-
arr[:] = arr / norm
708-
return arr
709-
else:
710-
return arr.copy() if copy else arr
711-
712-
if arr.ndim != 2:
713-
raise ValueError(GRADIENT_NORMALIZATION_SHAPE_ERROR_MSG)
714-
715-
rows, cols = arr.shape
716-
717-
# Prepare output (copy or in-place)
718-
normalized_arr = arr.copy() if copy else arr
719-
720-
# Determine where the 3-component vectors live and create a (N, 3) view
721-
# Check columns first to make Nx3/Nx4 deterministic
722-
if cols == 4:
723-
# Nx4: first 3 columns are b-vectors components, last are b-values
724-
vecs = normalized_arr[:, :3] # shape (N, 3)
725-
elif cols == 3:
726-
# Nx3: rows are vectors
727-
vecs = normalized_arr # shape (N, 3)
728-
elif rows == 4:
729-
# 4xN: first 3 rows are b-vector components, last row are b-values
730-
# Create a (N, 3) view by transposing first 3 rows
731-
vecs = normalized_arr[:3, :].T # shape (N, 3)
732-
elif rows == 3:
733-
# 3xN: columns are vectors: normalize per-column
734-
vecs = normalized_arr.T # shape (N, 3)
735-
else:
736-
raise ValueError(
737-
GRADIENT_NORMALIZATION_UNRECOGNIZED_SHAPE_ERROR_MSG.format(shape=arr.shape)
738-
)
739-
740-
# Normalize in-place on vecs (which is a view into output)
741-
norms = np.linalg.norm(vecs, axis=1)
742-
mask = norms > eps
743-
if np.any(mask):
744-
vecs[mask] = vecs[mask] / norms[mask, None]
745-
746-
return normalized_arr

0 commit comments

Comments
 (0)