Skip to content

Commit eecba2a

Browse files
committed
REF: Refactor gradient data checks in DWI data class
Refactor gradient data checks in DWI data class: PR #325 introduced the row-major convention for gradient data in NiFreeze. However, gradient loading was not tested thoroughly. This resulted in some execution flows that would not guarantee a row-major internal convention or would crash under some circumstances. This commit refactors the gradient data checks and adds thorough testing: - Define all error or warning messages as global variables so that they can be checked exactly in tests. - Ensure that gradients conform to the row-major convention when instantiating the DWI class directly. This allows to separate the gradient reformatting from the dimensionality check with the DWI volume sequence. This simplifies the flow, as the gradient reformatting to the row-major convention does not depend on the number of volumes in the DWI sequence. Also, this makes the flow more consistent with the refactored checks of the NIfTI file-based loading utility function (`from_nii`). - Ensure that the gradients conform to the row-major convention immediately after loading the gradients file in the NIfTI file-based loading utility function (`from_nii`). As opposed to the previous implementation, this allows to load the gradients from a file where data follows either column-major or row-major convention. e.g. In the previous implementation the `if grad.shape[1] < 2:` was making an assumption about the layout and/or one that was wrong because we require 4 columns (direction (x,y,z) + b-value) or rows (if in column-major). The new implementation simplifies the execution flow.
1 parent 997df63 commit eecba2a

File tree

2 files changed

+414
-36
lines changed

2 files changed

+414
-36
lines changed

src/nifreeze/data/dmri.py

Lines changed: 40 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,25 @@
6363
DTI_MIN_ORIENTATIONS = 6
6464
"""Minimum number of nonzero b-values in a DWI dataset."""
6565

66+
GRADIENT_VOLUME_DIMENSIONALITY_MISMATCH_MISSING_ERROR = "Gradient table shape does not match the number of diffusion volumes: expected {n_volumes} rows, found {n_gradients}."
67+
"""dMRI volume count vs. gradient count mismatch error message."""
68+
69+
GRADIENT_BVAL_BVEC_PRIORITY_WARN_MSG = "Both a gradients table file and b-vec/val files are defined; ignoring b-vec/val files in favor of the gradients_file."
70+
""""dMRI gradient file priority warning message."""
71+
72+
GRADIENT_NDIM_ERROR_MSG = "Gradient table must be a 2D array"
73+
"""dMRI gradient dimensionality error message."""
74+
75+
GRADIENT_DATA_MISSING_ERROR = (
76+
"No gradient data provided. Please specify either a gradients_file or (bvec_file & bval_file)."
77+
)
78+
"""dMRI missing gradient data error message."""
79+
80+
GRADIENT_EXPECTED_COLUMNS_ERROR_MSG = (
81+
"Gradient table must have four columns (3 direction components and one b-value)."
82+
)
83+
"""dMRI gradient expected columns error message."""
84+
6685

6786
@attrs.define(slots=True)
6887
class DWI(BaseDataset[np.ndarray]):
@@ -86,7 +105,13 @@ def _normalize_gradients(self) -> None:
86105

87106
gradients = np.asarray(self.gradients)
88107
if gradients.ndim != 2:
89-
raise ValueError("Gradient table must be a 2D array")
108+
raise ValueError(GRADIENT_NDIM_ERROR_MSG)
109+
if gradients.shape[1] == 4:
110+
pass
111+
elif gradients.shape[0] == 4:
112+
gradients = gradients.T
113+
else:
114+
raise ValueError(GRADIENT_EXPECTED_COLUMNS_ERROR_MSG)
90115

91116
n_volumes = None
92117
if self.dataobj is not None:
@@ -96,15 +121,11 @@ def _normalize_gradients(self) -> None:
96121
n_volumes = None
97122

98123
if n_volumes is not None and gradients.shape[0] != n_volumes:
99-
if gradients.shape[1] == n_volumes:
100-
gradients = gradients.T
101-
else:
102-
raise ValueError(
103-
"Gradient table shape does not match the number of diffusion volumes: "
104-
f"expected {n_volumes} rows, found {gradients.shape[0]}"
124+
raise ValueError(
125+
GRADIENT_VOLUME_DIMENSIONALITY_MISMATCH_MISSING_ERROR.format(
126+
n_volumes=n_volumes, n_gradients=gradients.shape[0]
105127
)
106-
elif n_volumes is None and gradients.shape[1] > gradients.shape[0]:
107-
gradients = gradients.T
128+
)
108129

109130
self.gradients = gradients
110131

@@ -383,11 +404,15 @@ def from_nii(
383404
if gradients_file:
384405
grad = np.loadtxt(gradients_file, dtype="float32")
385406
if bvec_file and bval_file:
386-
warn(
387-
"Both a gradients table file and b-vec/val files are defined; "
388-
"ignoring b-vec/val files in favor of the gradients_file.",
389-
stacklevel=2,
390-
)
407+
warn(GRADIENT_BVAL_BVEC_PRIORITY_WARN_MSG, stacklevel=2)
408+
if grad.ndim != 2:
409+
raise ValueError(GRADIENT_NDIM_ERROR_MSG)
410+
if grad.shape[1] == 4:
411+
pass
412+
elif grad.shape[0] == 4:
413+
grad = grad.T
414+
else:
415+
raise ValueError(GRADIENT_EXPECTED_COLUMNS_ERROR_MSG)
391416
elif bvec_file and bval_file:
392417
bvecs = np.loadtxt(bvec_file, dtype="float32")
393418
if bvecs.ndim == 1:
@@ -400,24 +425,7 @@ def from_nii(
400425
bvals = np.squeeze(bvals)
401426
grad = np.column_stack((bvecs, bvals))
402427
else:
403-
raise RuntimeError(
404-
"No gradient data provided. "
405-
"Please specify either a gradients_file or (bvec_file & bval_file)."
406-
)
407-
408-
if grad.ndim == 1:
409-
grad = grad[np.newaxis, :]
410-
411-
if grad.shape[1] < 2:
412-
raise ValueError("Gradient table must have at least two columns (direction + b-value).")
413-
414-
if grad.shape[1] != 4:
415-
if grad.shape[0] == 4:
416-
grad = grad.T
417-
else:
418-
raise ValueError(
419-
"Gradient table must have four columns (3 direction components and one b-value)."
420-
)
428+
raise RuntimeError(GRADIENT_DATA_MISSING_ERROR)
421429

422430
# 3) Create the DWI instance. We'll filter out volumes where b-value > b0_thres
423431
# as "DW volumes" if the user wants to store only the high-b volumes here

0 commit comments

Comments
 (0)