|
25 | 25 | import re |
26 | 26 | from pathlib import Path |
27 | 27 |
|
| 28 | +import attrs |
28 | 29 | import nibabel as nb |
29 | 30 | import numpy as np |
30 | 31 | import pytest |
31 | 32 |
|
32 | 33 | from nifreeze.data import load |
33 | 34 | from nifreeze.data.dmri import ( |
34 | 35 | DWI, |
| 36 | + GRADIENT_ABSENCE_ERROR_MSG, |
35 | 37 | GRADIENT_BVAL_BVEC_PRIORITY_WARN_MSG, |
36 | 38 | GRADIENT_DATA_MISSING_ERROR, |
37 | 39 | GRADIENT_EXPECTED_COLUMNS_ERROR_MSG, |
38 | 40 | GRADIENT_NDIM_ERROR_MSG, |
| 41 | + GRADIENT_OBJECT_ERROR_MSG, |
39 | 42 | GRADIENT_VOLUME_DIMENSIONALITY_MISMATCH_ERROR, |
40 | 43 | find_shelling_scheme, |
| 44 | + format_gradients, |
41 | 45 | from_nii, |
42 | 46 | transform_fsl_bvec, |
| 47 | + validate_gradients, |
43 | 48 | ) |
44 | 49 | from nifreeze.utils.ndimage import load_api |
45 | 50 |
|
@@ -88,6 +93,73 @@ def test_main(datadir): |
88 | 93 | assert isinstance(load(input_file), DWI) |
89 | 94 |
|
90 | 95 |
|
| 96 | +@pytest.mark.parametrize( |
| 97 | + "value, expected_exc, expected_msg", |
| 98 | + [ |
| 99 | + (np.array([[1], [2]], dtype=object), ValueError, GRADIENT_EXPECTED_COLUMNS_ERROR_MSG), |
| 100 | + (np.zeros((2, 3)), ValueError, GRADIENT_EXPECTED_COLUMNS_ERROR_MSG), |
| 101 | + ], |
| 102 | +) |
| 103 | +def test_validate_gradients(monkeypatch, value, expected_exc, expected_msg): |
| 104 | + monkeypatch.setattr(DWI, "__init__", lambda self, *a, **k: None) |
| 105 | + inst = DWI() |
| 106 | + dummy_attr = attrs.fields(DWI).gradients |
| 107 | + with pytest.raises(expected_exc, match=re.escape(str(expected_msg))): |
| 108 | + validate_gradients(inst, dummy_attr, value) |
| 109 | + |
| 110 | + |
| 111 | +@pytest.mark.parametrize( |
| 112 | + "value, expected_exc, expected_msg", |
| 113 | + [ |
| 114 | + (None, ValueError, GRADIENT_ABSENCE_ERROR_MSG), |
| 115 | + (3.14, ValueError, GRADIENT_NDIM_ERROR_MSG), |
| 116 | + ([1, 2, 3, 4], ValueError, GRADIENT_NDIM_ERROR_MSG), |
| 117 | + (np.arange(24).reshape(4, 3, 2), ValueError, GRADIENT_NDIM_ERROR_MSG), |
| 118 | + ([[1, 2], [3, 4, 5]], (TypeError, ValueError), GRADIENT_OBJECT_ERROR_MSG), # Ragged |
| 119 | + ], |
| 120 | +) |
| 121 | +def test_format_gradients_errors(value, expected_exc, expected_msg): |
| 122 | + with pytest.raises(expected_exc, match=str(expected_msg)): |
| 123 | + format_gradients(value) |
| 124 | + |
| 125 | + |
| 126 | +@pytest.mark.parametrize( |
| 127 | + "value, expect_transpose", |
| 128 | + [ |
| 129 | + # 2D arrays where first dim == 4 and second dim == 4 -> NO transpose |
| 130 | + (np.arange(16).reshape(4, 4), False), |
| 131 | + # 2D arrays where first dim == 4 and second dim != 4 -> transpose |
| 132 | + (np.arange(12).reshape(4, 3), True), |
| 133 | + (np.arange(4).reshape(4, 1), True), |
| 134 | + (np.empty((4, 0)), True), # zero columns -> still triggers transpose |
| 135 | + (np.arange(20).reshape(4, 5), True), |
| 136 | + # 2D arrays where first dim != 4 -> NO transpose |
| 137 | + (np.arange(12).reshape(3, 4), False), |
| 138 | + (np.arange(20).reshape(5, 4), False), |
| 139 | + # List of lists |
| 140 | + ([[1, 2, 3, 4], [5, 6, 7, 8]], False), |
| 141 | + ([[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]], True), |
| 142 | + ], |
| 143 | +) |
| 144 | +def test_format_gradients_basic(value, expect_transpose): |
| 145 | + obtained = format_gradients(value) |
| 146 | + |
| 147 | + assert isinstance(obtained, np.ndarray) |
| 148 | + if expect_transpose: |
| 149 | + assert obtained.shape == np.asarray(value).T.shape |
| 150 | + assert np.allclose(obtained, np.asarray(value).T) |
| 151 | + else: |
| 152 | + assert obtained.shape == np.asarray(value).shape |
| 153 | + assert np.allclose(obtained, np.asarray(value)) |
| 154 | + |
| 155 | + |
| 156 | +@pytest.mark.random_uniform_spatial_data((2, 2, 2, 4), 0.0, 1.0) |
| 157 | +def test_gradients_absence_error(setup_random_uniform_spatial_data): |
| 158 | + data, affine = setup_random_uniform_spatial_data |
| 159 | + with pytest.raises(ValueError, match=GRADIENT_ABSENCE_ERROR_MSG): |
| 160 | + DWI(dataobj=data, affine=affine) |
| 161 | + |
| 162 | + |
91 | 163 | @pytest.mark.random_gtab_data(10, (1000, 2000), 2) |
92 | 164 | @pytest.mark.random_dwi_data(50, (34, 36, 24), True) |
93 | 165 | @pytest.mark.parametrize("row_major_gradients", (False, True)) |
|
0 commit comments