|
22 | 22 | # |
23 | 23 | """Unit tests exercising the dMRI data structure.""" |
24 | 24 |
|
| 25 | +import re |
25 | 26 | from pathlib import Path |
| 27 | +from string import Formatter |
26 | 28 |
|
27 | 29 | import nibabel as nb |
28 | 30 | import numpy as np |
29 | 31 | import pytest |
30 | 32 |
|
31 | 33 | from nifreeze.data import load |
32 | | -from nifreeze.data.dmri import DWI, find_shelling_scheme, from_nii, transform_fsl_bvec |
| 34 | +from nifreeze.data.dmri import ( |
| 35 | + DWI, |
| 36 | + GRADIENT_NORMALIZATION_LENGTH_ERROR_MSG, |
| 37 | + GRADIENT_NORMALIZATION_SHAPE_ERROR_MSG, |
| 38 | + GRADIENT_NORMALIZATION_UNRECOGNIZED_SHAPE_ERROR_MSG, |
| 39 | + find_shelling_scheme, |
| 40 | + from_nii, |
| 41 | + normalize_gradients, |
| 42 | + transform_fsl_bvec, |
| 43 | +) |
33 | 44 | from nifreeze.utils.ndimage import load_api |
34 | 45 |
|
35 | 46 |
|
| 47 | +def _template_has_field(template: str, field_name: str | None = None) -> bool: |
| 48 | + """Return True if `template` contains a format field. |
| 49 | + If `field_name` is provided, return True only if that named field appears. |
| 50 | +
|
| 51 | + This uses Formatter.parse() so it recognizes real format fields and |
| 52 | + ignores literal substrings that merely look like "{shape}". |
| 53 | + """ |
| 54 | + formatter = Formatter() |
| 55 | + for _literal_text, field, _format_spec, _conversion in formatter.parse(template): |
| 56 | + if field is None: |
| 57 | + # no field in this segment |
| 58 | + continue |
| 59 | + # field can be '' (positional {}), 'shape', or complex like 'shape[0]' or 'obj.attr' |
| 60 | + if field_name is None: |
| 61 | + return True |
| 62 | + # Compare the base name before any attribute/indexing syntax |
| 63 | + base = field.split(".", 1)[0].split("[", 1)[0] |
| 64 | + if base == field_name: |
| 65 | + return True |
| 66 | + return False |
| 67 | + |
| 68 | + |
36 | 69 | def _dwi_data_to_nifti( |
37 | 70 | dwi_dataobj, |
38 | 71 | affine, |
@@ -959,3 +992,145 @@ def test_transform_fsl_bvec(b_ijk, zooms, flips, axis_order, origin, angles): |
959 | 992 | f"Expected {rotated_b_ijk}, got {test_b_ijk} for b_ijk={b_ijk}, " |
960 | 993 | f"zooms={zooms}, origin={origin}, angles={angles}" |
961 | 994 | ) |
| 995 | + |
| 996 | + |
| 997 | +@pytest.mark.parametrize( |
| 998 | + "shape, expected_msg_template", |
| 999 | + [ |
| 1000 | + # 1D but wrong length |
| 1001 | + ((4,), GRADIENT_NORMALIZATION_LENGTH_ERROR_MSG), |
| 1002 | + # ndim != 1 and != 2 |
| 1003 | + ((2, 2, 2), GRADIENT_NORMALIZATION_SHAPE_ERROR_MSG), |
| 1004 | + # 2D but unrecognized shape (neither Nx3/Nx4 nor 3xN/4xN) |
| 1005 | + ((2, 2), GRADIENT_NORMALIZATION_UNRECOGNIZED_SHAPE_ERROR_MSG), |
| 1006 | + ], |
| 1007 | +) |
| 1008 | +def test_normalize_gradients_exceptions(shape, expected_msg_template): |
| 1009 | + arr = np.zeros(shape, dtype=float) |
| 1010 | + if _template_has_field(expected_msg_template, "shape"): |
| 1011 | + expected_msg = expected_msg_template.format(shape=shape) |
| 1012 | + else: |
| 1013 | + expected_msg = expected_msg_template |
| 1014 | + |
| 1015 | + with pytest.raises(ValueError, match=re.escape(expected_msg)): |
| 1016 | + normalize_gradients(arr) |
| 1017 | + |
| 1018 | + |
| 1019 | +@pytest.mark.parametrize( |
| 1020 | + "arr, expected", |
| 1021 | + [ |
| 1022 | + # Nx3: rows are b-vectors (e.g., [gx gy gz]) |
| 1023 | + ( |
| 1024 | + np.array([[1, 0, 0], [0, 2, 0], [0, 0, 0]], float), |
| 1025 | + np.array([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 0.0]], float), |
| 1026 | + ), |
| 1027 | + ( |
| 1028 | + np.array([[2, 3, 1], [7, 5, 6], [0, 0, 0]], float), |
| 1029 | + np.array( |
| 1030 | + [ |
| 1031 | + [2.0 / np.sqrt(14), 3.0 / np.sqrt(14), 1.0 / np.sqrt(14)], |
| 1032 | + [7.0 / np.sqrt(110), 5.0 / np.sqrt(110), 6.0 / np.sqrt(110)], |
| 1033 | + [0.0, 0.0, 0.0], |
| 1034 | + ], |
| 1035 | + float, |
| 1036 | + ), |
| 1037 | + ), |
| 1038 | + # Nx4: first 3 columns are b-vectors (e.g., [gx gy gz b]) |
| 1039 | + ( |
| 1040 | + np.array([[1, 0, 0, 0], [0, 2, 0, 1000], [0, 0, 0, 0]], float), |
| 1041 | + np.array([[1.0, 0.0, 0.0, 0.0], [0.0, 1.0, 0.0, 1000.0], [0.0, 0.0, 0.0, 0.0]], float), |
| 1042 | + ), |
| 1043 | + ( |
| 1044 | + np.array([[0, 0, 0, 0], [1, 0, 2, 1000], [1, 2, 1, 1000]], float), |
| 1045 | + np.array( |
| 1046 | + [ |
| 1047 | + [0, 0, 0, 0], |
| 1048 | + [1.0 / np.sqrt(5), 0, 2.0 / np.sqrt(5), 1000], |
| 1049 | + [1.0 / np.sqrt(6), 2.0 / np.sqrt(6), 1.0 / np.sqrt(6), 1000], |
| 1050 | + ], |
| 1051 | + float, |
| 1052 | + ), |
| 1053 | + ), |
| 1054 | + ( |
| 1055 | + np.array([[4.0, 2.0, 1.0, 250.0]], float), |
| 1056 | + np.array([[4.0 / np.sqrt(21), 2.0 / np.sqrt(21), 1.0 / np.sqrt(21), 250.0]], float), |
| 1057 | + ), |
| 1058 | + # 3xN: columns are b-vectors (e.g., [gx gy gz].T) |
| 1059 | + ( |
| 1060 | + np.array([[1, 0], [0, 2], [0, 0]], float), |
| 1061 | + np.array([[1.0, 0.0], [0.0, 1.0], [0.0, 0.0]], float), |
| 1062 | + ), |
| 1063 | + ( |
| 1064 | + np.array([[8.0, 0.0], [1.0, 0.0], [6.0, 0.0]], float), |
| 1065 | + np.array( |
| 1066 | + [[8.0 / np.sqrt(101), 0.0], [1.0 / np.sqrt(101), 0.0], [6.0 / np.sqrt(101), 0.0]], |
| 1067 | + float, |
| 1068 | + ), |
| 1069 | + ), |
| 1070 | + # 4xN: first 3 rows are b-vectors (e.g., [gx gy gz b].T) |
| 1071 | + ( |
| 1072 | + np.array([[1, 0], [0, 2], [0, 0], [0, 1000]], float), |
| 1073 | + np.array([[1.0, 0.0], [0.0, 1.0], [0.0, 0.0], [0.0, 1000.0]], float), |
| 1074 | + ), |
| 1075 | + ( |
| 1076 | + np.array([[6.0, 0.0], [8.0, 0.0], [0.0, 0.0], [5.0, 200.0]], float), |
| 1077 | + np.array([[0.6, 0.0], [0.8, 0.0], [0.0, 0.0], [5.0, 200.0]], float), |
| 1078 | + ), |
| 1079 | + # 1D single vector |
| 1080 | + (np.array([3, 0, 0], float), np.array([1.0, 0.0, 0.0], float)), |
| 1081 | + (np.array([3.0, 4.0, 0.0], float), np.array([0.6, 0.8, 0.0], float)), |
| 1082 | + ], |
| 1083 | +) |
| 1084 | +def test_normalize_gradients_shapes(arr, expected): |
| 1085 | + """Normalize several common bvec layouts and compare to expected output.""" |
| 1086 | + obtained = normalize_gradients(arr) # default copy=True |
| 1087 | + |
| 1088 | + assert obtained.shape == expected.shape |
| 1089 | + assert np.allclose(obtained, expected) |
| 1090 | + |
| 1091 | + |
| 1092 | +@pytest.mark.parametrize( |
| 1093 | + "arr, idx_check, expected_row", |
| 1094 | + [ |
| 1095 | + # Nx3 in-place: ensure modification and returned object identity |
| 1096 | + (np.array([[0, 3, 0], [0, 0, 0]], float), (0,), np.array([0.0, 1.0, 0.0])), |
| 1097 | + # 1D single vector in-place |
| 1098 | + (np.array([3, 0, 0], float), None, np.array([1.0, 0.0, 0.0])), |
| 1099 | + ], |
| 1100 | +) |
| 1101 | +def test_normalize_gradients_inplace(arr, idx_check, expected_row): |
| 1102 | + """ |
| 1103 | + Ensure copy=False modifies the provided ndarray in-place and that returned |
| 1104 | + object is the same when appropriate. For 1D arrays the returned object |
| 1105 | + should be the same object when copy=False. |
| 1106 | + """ |
| 1107 | + arr_copy = arr.copy() |
| 1108 | + obtained = normalize_gradients(arr_copy, copy=False) |
| 1109 | + |
| 1110 | + # returned object must be the exact same ndarray when copy=False |
| 1111 | + assert obtained is arr_copy |
| 1112 | + |
| 1113 | + if idx_check is None: |
| 1114 | + # 1D vector: compare whole array |
| 1115 | + assert np.allclose(arr_copy, expected_row) |
| 1116 | + else: |
| 1117 | + # For multi-row arrays, check the indicated row(s) |
| 1118 | + # idx_check is a tuple of row indices to check (here only first row) |
| 1119 | + for i, expected in zip(idx_check, [expected_row]): |
| 1120 | + assert np.allclose(arr_copy[i], expected) |
| 1121 | + |
| 1122 | + |
| 1123 | +def test_normalize_gradients_zero_vectors_preserved_and_norms(): |
| 1124 | + """Check that near-zero vectors are left unchanged and non-zero are unit.""" |
| 1125 | + a = np.array([[1e-12, 0, 0], [0, 2, 0], [0, 0, 0]], float) |
| 1126 | + eps = 1e-8 |
| 1127 | + obtained = normalize_gradients(a, eps=eps) |
| 1128 | + |
| 1129 | + # First row is near-zero with norm < eps -> preserved (close to original) |
| 1130 | + assert np.allclose(obtained[0], a[0]) |
| 1131 | + |
| 1132 | + # Second row normalized to unit length |
| 1133 | + assert np.allclose(np.linalg.norm(obtained[1]), 1.0) |
| 1134 | + |
| 1135 | + # Last row is exactly zero and preserved |
| 1136 | + assert np.allclose(obtained[2], np.zeros(3)) |
0 commit comments