Skip to content

Commit f1ebf08

Browse files
committed
REF: Improve gradient formatting function robustness
Improve gradient formatting function robustness by adopting a more defensive approach: - Raise if gradients are `None`. - Raise if gradients are not a numeric homogeneous array-like object. Add the corresponding tests. Take advantage of the commit to test the `validate_gradients` function.
1 parent f182efc commit f1ebf08

File tree

2 files changed

+105
-1
lines changed

2 files changed

+105
-1
lines changed

src/nifreeze/data/dmri.py

Lines changed: 33 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,12 @@
6363
DTI_MIN_ORIENTATIONS = 6
6464
"""Minimum number of nonzero b-values in a DWI dataset."""
6565

66+
GRADIENT_ABSENCE_ERROR_MSG = "Gradient table may not be None."
67+
"""Gradient absence error message."""
68+
69+
GRADIENT_OBJECT_ERROR_MSG = "Gradient table must be a numeric homogeneous array-like object"
70+
"""Gradient object error message."""
71+
6672
GRADIENT_VOLUME_DIMENSIONALITY_MISMATCH_ERROR = """\
6773
Gradient table shape does not match the number of diffusion volumes: \
6874
expected {n_volumes} rows, found {n_gradients}."""
@@ -141,12 +147,38 @@ def format_gradients(
141147
...
142148
ValueError: Gradient table must be a 2D array
143149
150+
Gradient tables must have a regular shape::
151+
152+
>>> format_gradients([[1, 2], [3, 4, 5]])
153+
Traceback (most recent call last):
154+
...
155+
TypeError: Gradient table must be a numeric homogeneous array-like object
156+
157+
Gradient tables must always have two dimensions::
158+
159+
>>> format_gradients([0, 1, 0, 1000])
160+
Traceback (most recent call last):
161+
...
162+
ValueError: Gradient table must be a 2D array
163+
144164
"""
145165

146-
formatted = np.asarray(value)
166+
if value is None:
167+
raise ValueError(GRADIENT_ABSENCE_ERROR_MSG)
168+
169+
try:
170+
formatted = np.asarray(value, dtype=float)
171+
except (TypeError, ValueError) as exc:
172+
# Conversion failed (e.g. nested ragged objects, non-numeric)
173+
raise TypeError(GRADIENT_OBJECT_ERROR_MSG) from exc
174+
147175
if formatted.ndim != 2:
148176
raise ValueError(GRADIENT_NDIM_ERROR_MSG)
149177

178+
# If the numeric values are all integers, preserve integer dtype
179+
if np.all(np.isfinite(formatted)) and np.allclose(formatted, np.round(formatted)):
180+
formatted = formatted.astype(int)
181+
150182
# Transpose if column-major
151183
return formatted.T if formatted.shape[0] == 4 and formatted.shape[1] != 4 else formatted
152184

test/test_data_dmri.py

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
#
2323
"""Unit tests exercising the dMRI data structure."""
2424

25+
import attrs
2526
import re
2627
from pathlib import Path
2728

@@ -32,14 +33,18 @@
3233
from nifreeze.data import load
3334
from nifreeze.data.dmri import (
3435
DWI,
36+
GRADIENT_ABSENCE_ERROR_MSG,
3537
GRADIENT_BVAL_BVEC_PRIORITY_WARN_MSG,
3638
GRADIENT_DATA_MISSING_ERROR,
3739
GRADIENT_EXPECTED_COLUMNS_ERROR_MSG,
3840
GRADIENT_NDIM_ERROR_MSG,
41+
GRADIENT_OBJECT_ERROR_MSG,
3942
GRADIENT_VOLUME_DIMENSIONALITY_MISMATCH_ERROR,
4043
find_shelling_scheme,
44+
format_gradients,
4145
from_nii,
4246
transform_fsl_bvec,
47+
validate_gradients,
4348
)
4449
from nifreeze.utils.ndimage import load_api
4550

@@ -88,6 +93,73 @@ def test_main(datadir):
8893
assert isinstance(load(input_file), DWI)
8994

9095

96+
@pytest.mark.parametrize(
97+
"value, expected_exc, expected_msg",
98+
[
99+
(np.array([[1], [2]], dtype=object), ValueError, GRADIENT_EXPECTED_COLUMNS_ERROR_MSG),
100+
(np.zeros((2, 3)), ValueError, GRADIENT_EXPECTED_COLUMNS_ERROR_MSG),
101+
],
102+
)
103+
def test_validate_gradients(monkeypatch, value, expected_exc, expected_msg):
104+
monkeypatch.setattr(DWI, "__init__", lambda self, *a, **k: None)
105+
inst = DWI()
106+
dummy_attr = attrs.fields(DWI).gradients
107+
with pytest.raises(expected_exc, match=re.escape(str(expected_msg))):
108+
validate_gradients(inst, dummy_attr, value)
109+
110+
111+
@pytest.mark.parametrize(
112+
"value, expected_exc, expected_msg",
113+
[
114+
(None, ValueError, GRADIENT_ABSENCE_ERROR_MSG),
115+
(3.14, ValueError, GRADIENT_NDIM_ERROR_MSG),
116+
([1, 2, 3, 4], ValueError, GRADIENT_NDIM_ERROR_MSG),
117+
(np.arange(24).reshape(4, 3, 2), ValueError, GRADIENT_NDIM_ERROR_MSG),
118+
([[1, 2], [3, 4, 5]], (TypeError, ValueError), GRADIENT_OBJECT_ERROR_MSG), # Ragged
119+
],
120+
)
121+
def test_format_gradients_errors(value, expected_exc, expected_msg):
122+
with pytest.raises(expected_exc, match=str(expected_msg)):
123+
format_gradients(value)
124+
125+
126+
@pytest.mark.parametrize(
127+
"value, expect_transpose",
128+
[
129+
# 2D arrays where first dim == 4 and second dim == 4 -> NO transpose
130+
(np.arange(16).reshape(4, 4), False),
131+
# 2D arrays where first dim == 4 and second dim != 4 -> transpose
132+
(np.arange(12).reshape(4, 3), True),
133+
(np.arange(4).reshape(4, 1), True),
134+
(np.empty((4, 0)), True), # zero columns -> still triggers transpose
135+
(np.arange(20).reshape(4, 5), True),
136+
# 2D arrays where first dim != 4 -> NO transpose
137+
(np.arange(12).reshape(3, 4), False),
138+
(np.arange(20).reshape(5, 4), False),
139+
# List of lists
140+
([[1, 2, 3, 4], [5, 6, 7, 8]], False),
141+
([[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]], True),
142+
],
143+
)
144+
def test_format_gradients_basic(value, expect_transpose):
145+
obtained = format_gradients(value)
146+
147+
assert isinstance(obtained, np.ndarray)
148+
if expect_transpose:
149+
assert obtained.shape == np.asarray(value).T.shape
150+
assert np.allclose(obtained, np.asarray(value).T)
151+
else:
152+
assert obtained.shape == np.asarray(value).shape
153+
assert np.allclose(obtained, np.asarray(value))
154+
155+
156+
@pytest.mark.random_uniform_spatial_data((2, 2, 2, 4), 0.0, 1.0)
157+
def test_gradients_absence_error(setup_random_uniform_spatial_data):
158+
data, affine = setup_random_uniform_spatial_data
159+
with pytest.raises(ValueError, match=GRADIENT_ABSENCE_ERROR_MSG):
160+
DWI(dataobj=data, affine=affine)
161+
162+
91163
@pytest.mark.random_gtab_data(10, (1000, 2000), 2)
92164
@pytest.mark.random_dwi_data(50, (34, 36, 24), True)
93165
@pytest.mark.parametrize("row_major_gradients", (False, True))

0 commit comments

Comments
 (0)