Skip to content

Commit 6cd6e78

Browse files
committed
REF: Refactor gradient data checks in DWI data class
Refactor gradient data checks in DWI data class: PR #325 introduced the row-major convention for gradient data in NiFreeze. However, gradient loading was not tested thoroughly. This resulted in some execution flows that would not guarantee a row-major internal convention or would crash under some circumstances. This commit refactors the gradient data checks and adds thorough testing: - Define all error or warning messages as global variables so that they can be checked exactly in tests. - Ensure that gradients conform to the row-major convention when instantiating the DWI class directly. This allows to separate the gradient reformatting from the dimensionality check with the DWI volume sequence. This simplifies the flow, as the gradient reformatting to the row-major convention does not depend on the number of volumes in the DWI sequence. Also, this makes the flow more consistent with the refactored checks of the NIfTI file-based loading utility function (`from_nii`). - Ensure that the gradients conform to the row-major convention immediately after loading the gradients file in the NIfTI file-based loading utility function (`from_nii`). As opposed to the previous implementation, this allows to load the gradients from a file where data follows either column-major or row-major convention. e.g. In the previous implementation the `if grad.shape[1] < 2:` was making an assumption about the layout and/or one that was wrong because we require 4 columns (direction (x,y,z) + b-value) or rows (if in column-major). The new implementation simplifies the execution flow.
1 parent c729280 commit 6cd6e78

File tree

2 files changed

+413
-36
lines changed

2 files changed

+413
-36
lines changed

src/nifreeze/data/dmri.py

Lines changed: 40 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,25 @@
6363
DTI_MIN_ORIENTATIONS = 6
6464
"""Minimum number of nonzero b-values in a DWI dataset."""
6565

66+
GRADIENT_VOLUME_DIMENSIONALITY_MISMATCH_MISSING_ERROR = "Gradient table shape does not match the number of diffusion volumes: expected {n_volumes} rows, found {n_gradients}."
67+
"""dMRI volume count vs. gradient count mismatch error message."""
68+
69+
GRADIENT_BVAL_BVEC_PRIORITY_WARN_MSG = "Both a gradients table file and b-vec/val files are defined; ignoring b-vec/val files in favor of the gradients_file."
70+
""""dMRI gradient file priority warning message."""
71+
72+
GRADIENT_NDIM_ERROR_MSG = "Gradient table must be a 2D array"
73+
"""dMRI gradient dimensionality error message."""
74+
75+
GRADIENT_DATA_MISSING_ERROR = (
76+
"No gradient data provided. Please specify either a gradients_file or (bvec_file & bval_file)."
77+
)
78+
"""dMRI missing gradient data error message."""
79+
80+
GRADIENT_EXPECTED_COLUMNS_ERROR_MSG = (
81+
"Gradient table must have four columns (3 direction components and one b-value)."
82+
)
83+
"""dMRI gradient expected columns error message."""
84+
6685

6786
@attrs.define(slots=True)
6887
class DWI(BaseDataset[np.ndarray]):
@@ -84,7 +103,13 @@ def _normalize_gradients(self) -> None:
84103

85104
gradients = np.asarray(self.gradients)
86105
if gradients.ndim != 2:
87-
raise ValueError("Gradient table must be a 2D array")
106+
raise ValueError(GRADIENT_NDIM_ERROR_MSG)
107+
if gradients.shape[1] == 4:
108+
pass
109+
elif gradients.shape[0] == 4:
110+
gradients = gradients.T
111+
else:
112+
raise ValueError(GRADIENT_EXPECTED_COLUMNS_ERROR_MSG)
88113

89114
n_volumes = None
90115
if self.dataobj is not None:
@@ -94,15 +119,11 @@ def _normalize_gradients(self) -> None:
94119
n_volumes = None
95120

96121
if n_volumes is not None and gradients.shape[0] != n_volumes:
97-
if gradients.shape[1] == n_volumes:
98-
gradients = gradients.T
99-
else:
100-
raise ValueError(
101-
"Gradient table shape does not match the number of diffusion volumes: "
102-
f"expected {n_volumes} rows, found {gradients.shape[0]}"
122+
raise ValueError(
123+
GRADIENT_VOLUME_DIMENSIONALITY_MISMATCH_MISSING_ERROR.format(
124+
n_volumes=n_volumes, n_gradients=gradients.shape[0]
103125
)
104-
elif n_volumes is None and gradients.shape[1] > gradients.shape[0]:
105-
gradients = gradients.T
126+
)
106127

107128
self.gradients = gradients
108129

@@ -381,11 +402,15 @@ def from_nii(
381402
if gradients_file:
382403
grad = np.loadtxt(gradients_file, dtype="float32")
383404
if bvec_file and bval_file:
384-
warn(
385-
"Both a gradients table file and b-vec/val files are defined; "
386-
"ignoring b-vec/val files in favor of the gradients_file.",
387-
stacklevel=2,
388-
)
405+
warn(GRADIENT_BVAL_BVEC_PRIORITY_WARN_MSG, stacklevel=2)
406+
if grad.ndim != 2:
407+
raise ValueError(GRADIENT_NDIM_ERROR_MSG)
408+
if grad.shape[1] == 4:
409+
pass
410+
elif grad.shape[0] == 4:
411+
grad = grad.T
412+
else:
413+
raise ValueError(GRADIENT_EXPECTED_COLUMNS_ERROR_MSG)
389414
elif bvec_file and bval_file:
390415
bvecs = np.loadtxt(bvec_file, dtype="float32")
391416
if bvecs.ndim == 1:
@@ -398,24 +423,7 @@ def from_nii(
398423
bvals = np.squeeze(bvals)
399424
grad = np.column_stack((bvecs, bvals))
400425
else:
401-
raise RuntimeError(
402-
"No gradient data provided. "
403-
"Please specify either a gradients_file or (bvec_file & bval_file)."
404-
)
405-
406-
if grad.ndim == 1:
407-
grad = grad[np.newaxis, :]
408-
409-
if grad.shape[1] < 2:
410-
raise ValueError("Gradient table must have at least two columns (direction + b-value).")
411-
412-
if grad.shape[1] != 4:
413-
if grad.shape[0] == 4:
414-
grad = grad.T
415-
else:
416-
raise ValueError(
417-
"Gradient table must have four columns (3 direction components and one b-value)."
418-
)
426+
raise RuntimeError(GRADIENT_DATA_MISSING_ERROR)
419427

420428
# 3) Create the DWI instance. We'll filter out volumes where b-value > b0_thres
421429
# as "DW volumes" if the user wants to store only the high-b volumes here

0 commit comments

Comments
 (0)