Skip to content

Commit b5acdeb

Browse files
committed
ENH: Add gradient unit-sphere normalization function
Add gradient unit-sphere normalization function.
1 parent 6723229 commit b5acdeb

File tree

2 files changed

+266
-1
lines changed

2 files changed

+266
-1
lines changed

src/nifreeze/data/dmri.py

Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,17 @@
6363
DTI_MIN_ORIENTATIONS = 6
6464
"""Minimum number of nonzero b-values in a DWI dataset."""
6565

66+
GRADIENT_NORMALIZATION_SHAPE_ERROR_MSG = "Input must be 1D or 2D array."
67+
"""Gradient normalization shape error message."""
68+
69+
GRADIENT_NORMALIZATION_LENGTH_ERROR_MSG = "1D input must have length 3 to be a single b-vector."
70+
"""Gradient normalization length error message."""
71+
72+
GRADIENT_NORMALIZATION_UNRECOGNIZED_SHAPE_ERROR_MSG = (
73+
"Unrecognized shape {shape}. Expect Nx3, Nx4, 3xN, or 4xN (or 1D length-3)"
74+
)
75+
"""Gradient normalization unrecognized error message."""
76+
6677

6778
@attrs.define(slots=True)
6879
class DWI(BaseDataset[np.ndarray]):
@@ -106,6 +117,7 @@ def _normalize_gradients(self) -> None:
106117
elif n_volumes is None and gradients.shape[1] > gradients.shape[0]:
107118
gradients = gradients.T
108119

120+
normalize_gradients(gradients, copy=False)
109121
self.gradients = gradients
110122

111123
def _getextra(self, idx: int | slice | tuple | np.ndarray) -> tuple[np.ndarray]:
@@ -552,3 +564,81 @@ def transform_fsl_bvec(
552564
ijk2ijk_xfm = np.linalg.inv(imaffine) @ xfm @ imaffine
553565

554566
return ijk2ijk_xfm[:3, :3] @ b_ijk[:3]
567+
568+
569+
def normalize_gradients(value: np.ndarray, eps: float = 1e-8, copy: bool = True) -> np.ndarray:
570+
"""Normalize b-vectors in arrays of common shapes.
571+
572+
Parameters
573+
----------
574+
value : :obj:`~numpy.ndarray`
575+
Input array with shape one of:
576+
- (N, 3) : rows are b-vector components (e.g., [gx gy gz])
577+
- (N, 4) : first 3 columns are b-vector components (e.g., [gx gy gz b])
578+
- (3, N) : columns are b-vector components (e.g., [gx gy gz].T)
579+
- (4, N) : first 3 rows are b-vector components (e.g., [gx gy gz b].T)
580+
- (3,) or (1,3) or (3,1) : single b-vector
581+
Columns are checked first to disambiguate Nx3/Nx4 cases.
582+
eps : float, optional
583+
Threshold below which a vector is considered zero and left unchanged.
584+
copy : bool, optional
585+
If ``True``, returns a new array; modify in-place otherwise.
586+
587+
Returns
588+
-------
589+
out : :obj:`~numpy.ndarray`
590+
Array with the same shape as ``value`` with each 3-component b-vector
591+
normalized.
592+
"""
593+
arr = np.asarray(value, dtype=float)
594+
595+
# 1D single vector
596+
if arr.ndim == 1:
597+
if arr.size != 3:
598+
raise ValueError(GRADIENT_NORMALIZATION_LENGTH_ERROR_MSG)
599+
norm = np.linalg.norm(arr)
600+
if norm > eps:
601+
if copy:
602+
return arr / norm
603+
else:
604+
# Perform in-place normalization on the array view
605+
arr[:] = arr / norm
606+
return arr
607+
else:
608+
return arr.copy() if copy else arr
609+
610+
if arr.ndim != 2:
611+
raise ValueError(GRADIENT_NORMALIZATION_SHAPE_ERROR_MSG)
612+
613+
rows, cols = arr.shape
614+
615+
# Prepare output (copy or in-place)
616+
normalized_arr = arr.copy() if copy else arr
617+
618+
# Determine where the 3-component vectors live and create a (N, 3) view
619+
# Check columns first to make Nx3/Nx4 deterministic
620+
if cols == 4:
621+
# Nx4: first 3 columns are b-vectors components, last are b-values
622+
vecs = normalized_arr[:, :3] # shape (N, 3)
623+
elif cols == 3:
624+
# Nx3: rows are vectors
625+
vecs = normalized_arr # shape (N, 3)
626+
elif rows == 4:
627+
# 4xN: first 3 rows are b-vector components, last row are b-values
628+
# Create a (N, 3) view by transposing first 3 rows
629+
vecs = normalized_arr[:3, :].T # shape (N, 3)
630+
elif rows == 3:
631+
# 3xN: columns are vectors: normalize per-column
632+
vecs = normalized_arr.T # shape (N, 3)
633+
else:
634+
raise ValueError(
635+
GRADIENT_NORMALIZATION_UNRECOGNIZED_SHAPE_ERROR_MSG.format(shape=arr.shape)
636+
)
637+
638+
# Normalize in-place on vecs (which is a view into output)
639+
norms = np.linalg.norm(vecs, axis=1)
640+
mask = norms > eps
641+
if np.any(mask):
642+
vecs[mask] = vecs[mask] / norms[mask, None]
643+
644+
return normalized_arr

test/test_data_dmri.py

Lines changed: 176 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,17 +22,50 @@
2222
#
2323
"""Unit tests exercising the dMRI data structure."""
2424

25+
import re
2526
from pathlib import Path
27+
from string import Formatter
2628

2729
import nibabel as nb
2830
import numpy as np
2931
import pytest
3032

3133
from nifreeze.data import load
32-
from nifreeze.data.dmri import DWI, find_shelling_scheme, from_nii, transform_fsl_bvec
34+
from nifreeze.data.dmri import (
35+
DWI,
36+
GRADIENT_NORMALIZATION_LENGTH_ERROR_MSG,
37+
GRADIENT_NORMALIZATION_SHAPE_ERROR_MSG,
38+
GRADIENT_NORMALIZATION_UNRECOGNIZED_SHAPE_ERROR_MSG,
39+
find_shelling_scheme,
40+
from_nii,
41+
normalize_gradients,
42+
transform_fsl_bvec,
43+
)
3344
from nifreeze.utils.ndimage import load_api
3445

3546

47+
def _template_has_field(template: str, field_name: str | None = None) -> bool:
48+
"""Return True if `template` contains a format field.
49+
If `field_name` is provided, return True only if that named field appears.
50+
51+
This uses Formatter.parse() so it recognizes real format fields and
52+
ignores literal substrings that merely look like "{shape}".
53+
"""
54+
formatter = Formatter()
55+
for _literal_text, field, _format_spec, _conversion in formatter.parse(template):
56+
if field is None:
57+
# no field in this segment
58+
continue
59+
# field can be '' (positional {}), 'shape', or complex like 'shape[0]' or 'obj.attr'
60+
if field_name is None:
61+
return True
62+
# Compare the base name before any attribute/indexing syntax
63+
base = field.split(".", 1)[0].split("[", 1)[0]
64+
if base == field_name:
65+
return True
66+
return False
67+
68+
3669
def _dwi_data_to_nifti(
3770
dwi_dataobj,
3871
affine,
@@ -959,3 +992,145 @@ def test_transform_fsl_bvec(b_ijk, zooms, flips, axis_order, origin, angles):
959992
f"Expected {rotated_b_ijk}, got {test_b_ijk} for b_ijk={b_ijk}, "
960993
f"zooms={zooms}, origin={origin}, angles={angles}"
961994
)
995+
996+
997+
@pytest.mark.parametrize(
998+
"shape, expected_msg_template",
999+
[
1000+
# 1D but wrong length
1001+
((4,), GRADIENT_NORMALIZATION_LENGTH_ERROR_MSG),
1002+
# ndim != 1 and != 2
1003+
((2, 2, 2), GRADIENT_NORMALIZATION_SHAPE_ERROR_MSG),
1004+
# 2D but unrecognized shape (neither Nx3/Nx4 nor 3xN/4xN)
1005+
((2, 2), GRADIENT_NORMALIZATION_UNRECOGNIZED_SHAPE_ERROR_MSG),
1006+
],
1007+
)
1008+
def test_normalize_gradients_exceptions(shape, expected_msg_template):
1009+
arr = np.zeros(shape, dtype=float)
1010+
if _template_has_field(expected_msg_template, "shape"):
1011+
expected_msg = expected_msg_template.format(shape=shape)
1012+
else:
1013+
expected_msg = expected_msg_template
1014+
1015+
with pytest.raises(ValueError, match=re.escape(expected_msg)):
1016+
normalize_gradients(arr)
1017+
1018+
1019+
@pytest.mark.parametrize(
1020+
"arr, expected",
1021+
[
1022+
# Nx3: rows are b-vectors (e.g., [gx gy gz])
1023+
(
1024+
np.array([[1, 0, 0], [0, 2, 0], [0, 0, 0]], float),
1025+
np.array([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 0.0]], float),
1026+
),
1027+
(
1028+
np.array([[2, 3, 1], [7, 5, 6], [0, 0, 0]], float),
1029+
np.array(
1030+
[
1031+
[2.0 / np.sqrt(14), 3.0 / np.sqrt(14), 1.0 / np.sqrt(14)],
1032+
[7.0 / np.sqrt(110), 5.0 / np.sqrt(110), 6.0 / np.sqrt(110)],
1033+
[0.0, 0.0, 0.0],
1034+
],
1035+
float,
1036+
),
1037+
),
1038+
# Nx4: first 3 columns are b-vectors (e.g., [gx gy gz b])
1039+
(
1040+
np.array([[1, 0, 0, 0], [0, 2, 0, 1000], [0, 0, 0, 0]], float),
1041+
np.array([[1.0, 0.0, 0.0, 0.0], [0.0, 1.0, 0.0, 1000.0], [0.0, 0.0, 0.0, 0.0]], float),
1042+
),
1043+
(
1044+
np.array([[0, 0, 0, 0], [1, 0, 2, 1000], [1, 2, 1, 1000]], float),
1045+
np.array(
1046+
[
1047+
[0, 0, 0, 0],
1048+
[1.0 / np.sqrt(5), 0, 2.0 / np.sqrt(5), 1000],
1049+
[1.0 / np.sqrt(6), 2.0 / np.sqrt(6), 1.0 / np.sqrt(6), 1000],
1050+
],
1051+
float,
1052+
),
1053+
),
1054+
(
1055+
np.array([[4.0, 2.0, 1.0, 250.0]], float),
1056+
np.array([[4.0 / np.sqrt(21), 2.0 / np.sqrt(21), 1.0 / np.sqrt(21), 250.0]], float),
1057+
),
1058+
# 3xN: columns are b-vectors (e.g., [gx gy gz].T)
1059+
(
1060+
np.array([[1, 0], [0, 2], [0, 0]], float),
1061+
np.array([[1.0, 0.0], [0.0, 1.0], [0.0, 0.0]], float),
1062+
),
1063+
(
1064+
np.array([[8.0, 0.0], [1.0, 0.0], [6.0, 0.0]], float),
1065+
np.array(
1066+
[[8.0 / np.sqrt(101), 0.0], [1.0 / np.sqrt(101), 0.0], [6.0 / np.sqrt(101), 0.0]],
1067+
float,
1068+
),
1069+
),
1070+
# 4xN: first 3 rows are b-vectors (e.g., [gx gy gz b].T)
1071+
(
1072+
np.array([[1, 0], [0, 2], [0, 0], [0, 1000]], float),
1073+
np.array([[1.0, 0.0], [0.0, 1.0], [0.0, 0.0], [0.0, 1000.0]], float),
1074+
),
1075+
(
1076+
np.array([[6.0, 0.0], [8.0, 0.0], [0.0, 0.0], [5.0, 200.0]], float),
1077+
np.array([[0.6, 0.0], [0.8, 0.0], [0.0, 0.0], [5.0, 200.0]], float),
1078+
),
1079+
# 1D single vector
1080+
(np.array([3, 0, 0], float), np.array([1.0, 0.0, 0.0], float)),
1081+
(np.array([3.0, 4.0, 0.0], float), np.array([0.6, 0.8, 0.0], float)),
1082+
],
1083+
)
1084+
def test_normalize_gradients_shapes(arr, expected):
1085+
"""Normalize several common bvec layouts and compare to expected output."""
1086+
obtained = normalize_gradients(arr) # default copy=True
1087+
1088+
assert obtained.shape == expected.shape
1089+
assert np.allclose(obtained, expected)
1090+
1091+
1092+
@pytest.mark.parametrize(
1093+
"arr, idx_check, expected_row",
1094+
[
1095+
# Nx3 in-place: ensure modification and returned object identity
1096+
(np.array([[0, 3, 0], [0, 0, 0]], float), (0,), np.array([0.0, 1.0, 0.0])),
1097+
# 1D single vector in-place
1098+
(np.array([3, 0, 0], float), None, np.array([1.0, 0.0, 0.0])),
1099+
],
1100+
)
1101+
def test_normalize_gradients_inplace(arr, idx_check, expected_row):
1102+
"""
1103+
Ensure copy=False modifies the provided ndarray in-place and that returned
1104+
object is the same when appropriate. For 1D arrays the returned object
1105+
should be the same object when copy=False.
1106+
"""
1107+
arr_copy = arr.copy()
1108+
obtained = normalize_gradients(arr_copy, copy=False)
1109+
1110+
# returned object must be the exact same ndarray when copy=False
1111+
assert obtained is arr_copy
1112+
1113+
if idx_check is None:
1114+
# 1D vector: compare whole array
1115+
assert np.allclose(arr_copy, expected_row)
1116+
else:
1117+
# For multi-row arrays, check the indicated row(s)
1118+
# idx_check is a tuple of row indices to check (here only first row)
1119+
for i, expected in zip(idx_check, [expected_row]):
1120+
assert np.allclose(arr_copy[i], expected)
1121+
1122+
1123+
def test_normalize_gradients_zero_vectors_preserved_and_norms():
1124+
"""Check that near-zero vectors are left unchanged and non-zero are unit."""
1125+
a = np.array([[1e-12, 0, 0], [0, 2, 0], [0, 0, 0]], float)
1126+
eps = 1e-8
1127+
obtained = normalize_gradients(a, eps=eps)
1128+
1129+
# First row is near-zero with norm < eps -> preserved (close to original)
1130+
assert np.allclose(obtained[0], a[0])
1131+
1132+
# Second row normalized to unit length
1133+
assert np.allclose(np.linalg.norm(obtained[1]), 1.0)
1134+
1135+
# Last row is exactly zero and preserved
1136+
assert np.allclose(obtained[2], np.zeros(3))

0 commit comments

Comments
 (0)