Skip to content

Commit 2026fe4

Browse files
jhlegarretaoesteban
authored andcommitted
ENH: Add gradient unit-sphere normalization function
Add gradient unit-sphere normalization function.
1 parent da1e127 commit 2026fe4

File tree

2 files changed

+116
-1
lines changed

2 files changed

+116
-1
lines changed

src/nifreeze/data/dmri.py

Lines changed: 93 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,9 @@
5151
DEFAULT_LOWB_THRESHOLD = 50
5252
"""The lower bound for the b-value so that the orientation is considered a DW volume."""
5353

54+
DEFAULT_GRADIENT_EPS = 1e-8
55+
"""Epsilon value for b-vector normalization."""
56+
5457
DEFAULT_HIGHB_THRESHOLD = 8000
5558
"""A b-value cap for DWI data."""
5659

@@ -180,7 +183,18 @@ def format_gradients(
180183
formatted = formatted.astype(int)
181184

182185
# Transpose if column-major
183-
return formatted.T if formatted.shape[0] == 4 and formatted.shape[1] != 4 else formatted
186+
formatted = formatted.T if formatted.shape[0] == 4 and formatted.shape[1] != 4 else formatted
187+
188+
# Normalize b-vectors in-place
189+
bvecs = formatted[:, :3]
190+
norms = np.linalg.norm(bvecs, axis=1)
191+
mask = norms > DEFAULT_GRADIENT_EPS
192+
if np.any(mask):
193+
formatted[mask, :3] = bvecs[mask] / norms[mask, None] # Norm b-vectors
194+
formatted[mask, 3] *= norms[mask] # Scale b-values by norm
195+
formatted[~mask, :] = 0.0 # Zero-out small b-vectors
196+
197+
return formatted
184198

185199

186200
def validate_gradients(
@@ -652,3 +666,81 @@ def transform_fsl_bvec(
652666
ijk2ijk_xfm = np.linalg.inv(imaffine) @ xfm @ imaffine
653667

654668
return ijk2ijk_xfm[:3, :3] @ b_ijk[:3]
669+
670+
671+
def normalize_gradients(value: np.ndarray, eps: float = 1e-8, copy: bool = True) -> np.ndarray:
672+
"""Normalize b-vectors in arrays of common shapes.
673+
674+
Parameters
675+
----------
676+
value : :obj:`~numpy.ndarray`
677+
Input array with shape one of:
678+
- (N, 3) : rows are b-vector components (e.g., [gx gy gz])
679+
- (N, 4) : first 3 columns are b-vector components (e.g., [gx gy gz b])
680+
- (3, N) : columns are b-vector components (e.g., [gx gy gz].T)
681+
- (4, N) : first 3 rows are b-vector components (e.g., [gx gy gz b].T)
682+
- (3,) or (1,3) or (3,1) : single b-vector
683+
Columns are checked first to disambiguate Nx3/Nx4 cases.
684+
eps : float, optional
685+
Threshold below which a vector is considered zero and left unchanged.
686+
copy : bool, optional
687+
If ``True``, returns a new array; modify in-place otherwise.
688+
689+
Returns
690+
-------
691+
out : :obj:`~numpy.ndarray`
692+
Array with the same shape as ``value`` with each 3-component b-vector
693+
normalized.
694+
"""
695+
arr = np.asarray(value, dtype=float)
696+
697+
# 1D single vector
698+
if arr.ndim == 1:
699+
if arr.size != 3:
700+
raise ValueError(GRADIENT_NORMALIZATION_LENGTH_ERROR_MSG)
701+
norm = np.linalg.norm(arr)
702+
if norm > eps:
703+
if copy:
704+
return arr / norm
705+
else:
706+
# Perform in-place normalization on the array view
707+
arr[:] = arr / norm
708+
return arr
709+
else:
710+
return arr.copy() if copy else arr
711+
712+
if arr.ndim != 2:
713+
raise ValueError(GRADIENT_NORMALIZATION_SHAPE_ERROR_MSG)
714+
715+
rows, cols = arr.shape
716+
717+
# Prepare output (copy or in-place)
718+
normalized_arr = arr.copy() if copy else arr
719+
720+
# Determine where the 3-component vectors live and create a (N, 3) view
721+
# Check columns first to make Nx3/Nx4 deterministic
722+
if cols == 4:
723+
# Nx4: first 3 columns are b-vectors components, last are b-values
724+
vecs = normalized_arr[:, :3] # shape (N, 3)
725+
elif cols == 3:
726+
# Nx3: rows are vectors
727+
vecs = normalized_arr # shape (N, 3)
728+
elif rows == 4:
729+
# 4xN: first 3 rows are b-vector components, last row are b-values
730+
# Create a (N, 3) view by transposing first 3 rows
731+
vecs = normalized_arr[:3, :].T # shape (N, 3)
732+
elif rows == 3:
733+
# 3xN: columns are vectors: normalize per-column
734+
vecs = normalized_arr.T # shape (N, 3)
735+
else:
736+
raise ValueError(
737+
GRADIENT_NORMALIZATION_UNRECOGNIZED_SHAPE_ERROR_MSG.format(shape=arr.shape)
738+
)
739+
740+
# Normalize in-place on vecs (which is a view into output)
741+
norms = np.linalg.norm(vecs, axis=1)
742+
mask = norms > eps
743+
if np.any(mask):
744+
vecs[mask] = vecs[mask] / norms[mask, None]
745+
746+
return normalized_arr

test/test_data_dmri.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424

2525
import re
2626
from pathlib import Path
27+
from string import Formatter
2728

2829
import attrs
2930
import nibabel as nb
@@ -49,6 +50,28 @@
4950
from nifreeze.utils.ndimage import load_api
5051

5152

53+
def _template_has_field(template: str, field_name: str | None = None) -> bool:
54+
"""Return True if `template` contains a format field.
55+
If `field_name` is provided, return True only if that named field appears.
56+
57+
This uses Formatter.parse() so it recognizes real format fields and
58+
ignores literal substrings that merely look like "{shape}".
59+
"""
60+
formatter = Formatter()
61+
for _literal_text, field, _format_spec, _conversion in formatter.parse(template):
62+
if field is None:
63+
# no field in this segment
64+
continue
65+
# field can be '' (positional {}), 'shape', or complex like 'shape[0]' or 'obj.attr'
66+
if field_name is None:
67+
return True
68+
# Compare the base name before any attribute/indexing syntax
69+
base = field.split(".", 1)[0].split("[", 1)[0]
70+
if base == field_name:
71+
return True
72+
return False
73+
74+
5275
def _dwi_data_to_nifti(
5376
dwi_dataobj,
5477
affine,

0 commit comments

Comments
 (0)