Skip to content

Commit c782ac6

Browse files
committed
REF: Improve gradient formatting function robustness
Improve gradient formatting function robustness by adopting a more defensive approach: - Raise if gradients are `None`. - Raise if gradients are not a numeric homogeneous array-like object. Add the corresponding tests. Take advantage of the commit to test the `validate_gradients` function.
1 parent f182efc commit c782ac6

File tree

2 files changed

+97
-1
lines changed

2 files changed

+97
-1
lines changed

src/nifreeze/data/dmri.py

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,12 @@
6363
DTI_MIN_ORIENTATIONS = 6
6464
"""Minimum number of nonzero b-values in a DWI dataset."""
6565

66+
GRADIENT_ABSENCE_ERROR_MSG = "Gradient table may not be None."
67+
"""Gradient absence error message."""
68+
69+
GRADIENT_OBJECT_ERROR_MSG = "Gradient table must be a numeric homogeneous array-like object."
70+
"""Gradient object error message."""
71+
6672
GRADIENT_VOLUME_DIMENSIONALITY_MISMATCH_ERROR = """\
6773
Gradient table shape does not match the number of diffusion volumes: \
6874
expected {n_volumes} rows, found {n_gradients}."""
@@ -141,9 +147,31 @@ def format_gradients(
141147
...
142148
ValueError: Gradient table must be a 2D array
143149
150+
Gradient tables must have a regular shape::
151+
152+
>>> format_gradients([[1, 2], [3, 4, 5]])
153+
Traceback (most recent call last):
154+
...
155+
TypeError: Gradient table must be a numeric homogeneous array-like object
156+
157+
Gradient tables must always have two dimensions::
158+
159+
>>> format_gradients([0, 1, 0, 1000])
160+
Traceback (most recent call last):
161+
...
162+
ValueError: Gradient table must be a 2D array
163+
144164
"""
145165

146-
formatted = np.asarray(value)
166+
if value is None:
167+
raise ValueError(GRADIENT_ABSENCE_ERROR_MSG)
168+
169+
try:
170+
formatted = np.asarray(value, dtype=float)
171+
except (TypeError, ValueError) as exc:
172+
# Conversion failed (e.g. nested ragged objects, non-numeric)
173+
raise TypeError(GRADIENT_OBJECT_ERROR_MSG) from exc
174+
147175
if formatted.ndim != 2:
148176
raise ValueError(GRADIENT_NDIM_ERROR_MSG)
149177

test/test_data_dmri.py

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,14 +32,18 @@
3232
from nifreeze.data import load
3333
from nifreeze.data.dmri import (
3434
DWI,
35+
GRADIENT_ABSENCE_ERROR_MSG,
3536
GRADIENT_BVAL_BVEC_PRIORITY_WARN_MSG,
3637
GRADIENT_DATA_MISSING_ERROR,
3738
GRADIENT_EXPECTED_COLUMNS_ERROR_MSG,
3839
GRADIENT_NDIM_ERROR_MSG,
40+
GRADIENT_OBJECT_ERROR_MSG,
3941
GRADIENT_VOLUME_DIMENSIONALITY_MISMATCH_ERROR,
4042
find_shelling_scheme,
43+
format_gradients,
4144
from_nii,
4245
transform_fsl_bvec,
46+
validate_gradients,
4347
)
4448
from nifreeze.utils.ndimage import load_api
4549

@@ -88,6 +92,70 @@ def test_main(datadir):
8892
assert isinstance(load(input_file), DWI)
8993

9094

95+
@pytest.mark.parametrize(
96+
"value, expected_exc, expected_msg",
97+
[
98+
(np.array([[1], [2]], dtype=object), ValueError, GRADIENT_EXPECTED_COLUMNS_ERROR_MSG),
99+
(np.zeros((2, 3)), ValueError, GRADIENT_EXPECTED_COLUMNS_ERROR_MSG),
100+
],
101+
)
102+
def test_validate_gradients(value, expected_exc, expected_msg):
103+
with pytest.raises(expected_exc, match=re.escape(str(expected_msg))):
104+
validate_gradients(None, "gradients", value)
105+
106+
107+
@pytest.mark.parametrize(
108+
"value, expected_exc, expected_msg",
109+
[
110+
(None, ValueError, GRADIENT_ABSENCE_ERROR_MSG),
111+
(3.14, ValueError, GRADIENT_NDIM_ERROR_MSG),
112+
([1, 2, 3, 4], ValueError, GRADIENT_NDIM_ERROR_MSG),
113+
(np.arange(24).reshape(4, 3, 2), ValueError, GRADIENT_NDIM_ERROR_MSG),
114+
([[1, 2], [3, 4, 5]], (TypeError, ValueError), GRADIENT_OBJECT_ERROR_MSG), # Ragged
115+
],
116+
)
117+
def test_format_gradients_errors(value, expected_exc, expected_msg):
118+
with pytest.raises(expected_exc, match=str(expected_msg)):
119+
format_gradients(value)
120+
121+
122+
@pytest.mark.parametrize(
123+
"value, expect_transpose",
124+
[
125+
# 2D arrays where first dim == 4 and second dim == 4 -> NO transpose
126+
(np.arange(16).reshape(4, 4), False),
127+
# 2D arrays where first dim == 4 and second dim != 4 -> transpose
128+
(np.arange(12).reshape(4, 3), True),
129+
(np.arange(4).reshape(4, 1), True),
130+
(np.empty((4, 0)), True), # zero columns -> still triggers transpose
131+
(np.arange(20).reshape(4, 5), True),
132+
# 2D arrays where first dim != 4 -> NO transpose
133+
(np.arange(12).reshape(3, 4), False),
134+
(np.arange(20).reshape(5, 4), False),
135+
# List of lists
136+
([[1, 2, 3, 4], [5, 6, 7, 8]], False),
137+
([[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]], True),
138+
],
139+
)
140+
def test_format_gradients_basic(value, expect_transpose):
141+
obtained = format_gradients(value)
142+
143+
assert isinstance(obtained, np.ndarray)
144+
if expect_transpose:
145+
assert obtained.shape == np.asarray(value).T.shape
146+
assert np.allclose(obtained, np.asarray(value).T)
147+
else:
148+
assert obtained.shape == np.asarray(value).shape
149+
assert np.allclose(obtained, np.asarray(value))
150+
151+
152+
@pytest.mark.random_uniform_spatial_data((2, 2, 2, 4), 0.0, 1.0)
153+
def test_gradients_absence_error(setup_random_uniform_spatial_data):
154+
data, affine = setup_random_uniform_spatial_data
155+
with pytest.raises(ValueError, match=GRADIENT_ABSENCE_ERROR_MSG):
156+
DWI(dataobj=data, affine=affine)
157+
158+
91159
@pytest.mark.random_gtab_data(10, (1000, 2000), 2)
92160
@pytest.mark.random_dwi_data(50, (34, 36, 24), True)
93161
@pytest.mark.parametrize("row_major_gradients", (False, True))

0 commit comments

Comments
 (0)