|
1 | 1 | """Test off-resonance spatial coefficient and temporal interpolator estimation.""" |
2 | 2 |
|
3 | | -import math |
4 | | - |
| 3 | +from mrinufft.extras import get_orc_factorization, get_complex_fieldmap_rad |
5 | 4 | import numpy as np |
6 | | - |
| 5 | +import numpy.testing as npt |
7 | 6 | import pytest |
8 | | -from pytest_cases import parametrize_with_cases |
| 7 | +from pytest_cases import parametrize_with_cases, parametrize |
9 | 8 |
|
10 | 9 |
|
11 | | -import mrinufft |
12 | | -from mrinufft._array_compat import CUPY_AVAILABLE |
13 | | -from mrinufft._utils import get_array_module |
14 | 10 | from mrinufft.operators.off_resonance import MRIFourierCorrected |
15 | 11 | from mrinufft import get_operator |
| 12 | +from mrinufft.extras import make_b0map, make_t2smap |
16 | 13 |
|
17 | | -from helpers import to_interface, assert_allclose |
18 | | -from helpers.factories import _param_array_interface |
19 | | -from case_fieldmaps import CasesB0maps, CasesZmaps |
| 14 | +from helpers import to_interface |
| 15 | +from helpers.factories import _param_array_interface_np_cp, from_interface |
20 | 16 |
|
21 | 17 |
|
22 | | -def calculate_true_offresonance_term(fieldmap, t, array_interface): |
23 | | - """Calculate non-approximate off-resonance modulation term.""" |
24 | | - fieldmap = to_interface(fieldmap, array_interface) |
25 | | - t = to_interface(t, array_interface) |
| 18 | +class CasesB0maps: |
| 19 | + """B0 field maps cases we want to test. |
26 | 20 |
|
27 | | - xp = get_array_module(fieldmap) |
28 | | - arg = t * fieldmap[..., None] |
29 | | - arg = arg[None, ...].swapaxes(0, -1)[..., 0] |
30 | | - return xp.exp(-arg) |
| 21 | + Each case return a field map and the binary spatial support of the object. |
| 22 | + """ |
31 | 23 |
|
| 24 | + def case_real2D(self, N=64, b0_range=(-300, 300)): |
| 25 | + """Create a real (B0 only) 2D field map.""" |
| 26 | + b0_map, mask = make_b0map(2 * [N]) |
| 27 | + return b0_map, None, mask |
32 | 28 |
|
33 | | -def calculate_approx_offresonance_term(B, C): |
34 | | - """Calculate approximate off-resonance modulation term.""" |
35 | | - field_term = 0.0 |
36 | | - for n in range(B.shape[0]): |
37 | | - tmp = B[n] * C[n][..., None] |
38 | | - tmp = tmp[None, ...].swapaxes(0, -1)[..., 0] |
39 | | - field_term += tmp |
40 | | - return field_term |
| 29 | + # def case_real3D(self, N=32, b0range=(-300, 300)): |
| 30 | + # """Create a real (B0 only) 3D field map.""" |
| 31 | + # b0_map, mask = make_b0map(3 * [N], b0range) |
| 32 | + # return b0_map, None, mask |
41 | 33 |
|
| 34 | + def case_complex2D(self, N=64, b0range=(-300, 300), t2svalue=15.0): |
| 35 | + """Create a complex (R2* + 1j * B0) 2D field map.""" |
| 36 | + # Generate real and imaginary parts |
| 37 | + t2s_map, _ = make_t2smap(2 * [N], t2svalue) |
| 38 | + b0_map, mask = make_b0map(2 * [N], b0range) |
42 | 39 |
|
43 | | -@_param_array_interface |
44 | | -@parametrize_with_cases("b0map, mask", cases=CasesB0maps) |
45 | | -def test_b0map_coeff(b0map, mask, array_interface): |
46 | | - """Test exponential approximation for B0 field only.""" |
47 | | - if array_interface == "torch-gpu" and not CUPY_AVAILABLE: |
48 | | - pytest.skip("GPU computations requires cupy") |
| 40 | + # Convert T2* map to R2* map |
| 41 | + t2s_map = t2s_map * 1e-3 # ms -> s |
| 42 | + r2s_map = 1.0 / (t2s_map + 1e-9) # Hz |
| 43 | + r2s_map = mask * r2s_map |
49 | 44 |
|
50 | | - # Generate readout times |
51 | | - tread = np.linspace(0.0, 5e-3, 501, dtype=np.float32) |
| 45 | + return b0_map, r2s_map, mask |
52 | 46 |
|
53 | | - # Generate coefficients |
54 | | - B, tl = mrinufft.get_interpolators_from_fieldmap( |
55 | | - to_interface(b0map, array_interface), tread, mask=mask, n_time_segments=100 |
56 | | - ) |
| 47 | + # def case_complex3D(self, N=32, b0range=(-300, 300), t2svalue=15.0): |
| 48 | + # """Create a complex (R2* + 1j * B0) 3D field map.""" |
| 49 | + # # Generate real and imaginary parts |
| 50 | + # t2s_map, _ = make_t2smap(3 * [N], t2svalue) |
| 51 | + # b0_map, mask = make_b0map(3 * [N], b0range) |
57 | 52 |
|
58 | | - # Calculate spatial coefficients |
59 | | - C = MRIFourierCorrected.get_spatial_coefficients( |
60 | | - to_interface(2 * math.pi * 1j * b0map, array_interface), tl |
61 | | - ) |
| 53 | + # # Convert T2* map to R2* map |
| 54 | + # t2s_map = t2s_map * 1e-3 # ms -> s |
| 55 | + # r2s_map = 1.0 / (t2s_map + 1e-9) # Hz |
| 56 | + # r2s_map = mask * r2s_map |
| 57 | + # return b0_map, r2s_map, mask |
62 | 58 |
|
63 | | - # Assert properties |
64 | | - assert B.shape == (100, 501) |
65 | | - assert C.shape == (100, *b0map.shape) |
66 | | - |
67 | | - # Correct approximation |
68 | | - expected = calculate_true_offresonance_term( |
69 | | - 0 + 2 * math.pi * 1j * b0map, tread, array_interface |
70 | | - ) |
71 | | - actual = calculate_approx_offresonance_term(B, C) |
72 | | - assert_allclose(actual, expected, atol=1e-3, rtol=1e-3, interface=array_interface) |
73 | | - |
74 | | - |
75 | | -@_param_array_interface |
76 | | -@parametrize_with_cases("zmap, mask", cases=CasesZmaps) |
77 | | -def test_zmap_coeff(zmap, mask, array_interface): |
78 | | - """Test exponential approximation for complex Z = R2* + 1j *B0 field.""" |
79 | | - if array_interface == "torch-gpu" and CUPY_AVAILABLE is False: |
80 | | - pytest.skip("GPU computations requires cupy") |
81 | 59 |
|
| 60 | +@_param_array_interface_np_cp |
| 61 | +@parametrize_with_cases("b0_map, r2s_map, mask", cases=CasesB0maps) |
| 62 | +@parametrize("method", ["svd-full", "mti", "mfi"]) |
| 63 | +@parametrize("L", [40, -1]) |
| 64 | +def test_b0map_coeff(b0_map, r2s_map, mask, method, L, array_interface): |
| 65 | + """Test exponential approximation for B0 field only.""" |
82 | 66 | # Generate readout times |
83 | | - tread = np.linspace(0.0, 5e-3, 501, dtype=np.float32) |
84 | | - |
85 | | - # Generate coefficients |
86 | | - B, tl = mrinufft.get_interpolators_from_fieldmap( |
87 | | - to_interface(zmap.imag, array_interface), |
88 | | - tread, |
89 | | - mask=mask, |
90 | | - r2star_map=to_interface(zmap.real, array_interface), |
91 | | - n_time_segments=100, |
92 | | - ) |
93 | | - |
94 | | - # Calculate spatial coefficients |
95 | | - C = MRIFourierCorrected.get_spatial_coefficients( |
96 | | - to_interface(2 * math.pi * zmap, array_interface), tl |
| 67 | + Nt = 400 |
| 68 | + tread = np.linspace(0.0, 5e-3, Nt, dtype=np.float32) |
| 69 | + |
| 70 | + cpx_fieldmap = get_complex_fieldmap_rad(b0_map, r2s_map).astype(np.complex64) |
| 71 | + |
| 72 | + E_full = np.exp(np.outer(tread, cpx_fieldmap[mask])) |
| 73 | + |
| 74 | + kwargs = {} |
| 75 | + if method == "svd-full": # Truncated SVD is flacky (esp. for cupy) |
| 76 | + kwargs["partial_svd"] = False |
| 77 | + method = "svd" |
| 78 | + B, C, _ = get_orc_factorization(method)( |
| 79 | + to_interface(cpx_fieldmap, array_interface), |
| 80 | + to_interface(tread, array_interface), |
| 81 | + to_interface(mask, array_interface), |
| 82 | + L=L, |
| 83 | + lazy=False, |
| 84 | + n_bins=4096, |
| 85 | + **kwargs, |
97 | 86 | ) |
98 | 87 |
|
| 88 | + if L == -1: |
| 89 | + L = B.shape[1] |
| 90 | + print(L) |
99 | 91 | # Assert properties |
100 | | - assert B.shape == (100, 501) |
101 | | - assert C.shape == (100, *zmap.shape) |
| 92 | + assert B.shape == (Nt, L) |
| 93 | + assert C.shape == (L, *b0_map.shape) |
102 | 94 |
|
103 | | - # Correct approximation |
104 | | - expected = calculate_true_offresonance_term( |
105 | | - 2 * math.pi * zmap, tread, array_interface |
106 | | - ) |
107 | | - actual = calculate_approx_offresonance_term(B, C) |
108 | | - assert_allclose(actual, expected, atol=1e-3, rtol=1e-3, interface=array_interface) |
| 95 | + # Check that the approximation match the full matrix. |
| 96 | + B = from_interface(B, array_interface) |
| 97 | + C = from_interface(C, array_interface) |
| 98 | + E2 = B @ C[:, mask] |
| 99 | + # TODO get closer bound somehow ? |
| 100 | + npt.assert_allclose(E2, E_full, atol=5e-3, rtol=5e-3) |
109 | 101 |
|
110 | 102 |
|
111 | 103 | def test_b0_map_upsampling_warns_and_matches_shape(): |
@@ -134,5 +126,5 @@ def test_b0_map_upsampling_warns_and_matches_shape(): |
134 | 126 | ) |
135 | 127 |
|
136 | 128 | # check that no exception is raised and internal shape matches |
137 | | - assert op.B.shape[1] == len(readout_time) |
| 129 | + assert op.B.shape[0] == len(readout_time) |
138 | 130 | assert op.shape == shape_target |
0 commit comments