Skip to content

Commit 3a61bc9

Browse files
committed
feat: update test case for ORC factorization.
1 parent e23007b commit 3a61bc9

File tree

3 files changed

+91
-140
lines changed

3 files changed

+91
-140
lines changed

tests/case_fieldmaps.py

Lines changed: 0 additions & 57 deletions
This file was deleted.

tests/helpers/factories.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,8 @@ def to_interface(data, interface):
5454

5555
def from_interface(data, interface):
5656
"""Get DATA from INTERFACE as a numpy array."""
57+
if isinstance(data, np.ndarray):
58+
return data
5759
if interface == "cupy":
5860
return data.get()
5961
elif "torch" in interface:
@@ -86,6 +88,20 @@ def from_interface(data, interface):
8688
],
8789
)
8890

91+
_param_array_interface_np_cp = pytest.mark.parametrize(
92+
"array_interface",
93+
[
94+
"numpy",
95+
pytest.param(
96+
"cupy",
97+
marks=pytest.mark.skipif(
98+
not CUPY_AVAILABLE,
99+
reason="cupy not available",
100+
),
101+
),
102+
],
103+
)
104+
89105

90106
def param_array_interface(func):
91107
"""Parametrize the array interfaces for a test."""

tests/test_offres_exp_approx.py

Lines changed: 75 additions & 83 deletions
Original file line numberDiff line numberDiff line change
@@ -1,111 +1,103 @@
11
"""Test off-resonance spatial coefficient and temporal interpolator estimation."""
22

3-
import math
4-
3+
from mrinufft.extras import get_orc_factorization, get_complex_fieldmap_rad
54
import numpy as np
6-
5+
import numpy.testing as npt
76
import pytest
8-
from pytest_cases import parametrize_with_cases
7+
from pytest_cases import parametrize_with_cases, parametrize
98

109

11-
import mrinufft
12-
from mrinufft._array_compat import CUPY_AVAILABLE
13-
from mrinufft._utils import get_array_module
1410
from mrinufft.operators.off_resonance import MRIFourierCorrected
1511
from mrinufft import get_operator
12+
from mrinufft.extras import make_b0map, make_t2smap
1613

17-
from helpers import to_interface, assert_allclose
18-
from helpers.factories import _param_array_interface
19-
from case_fieldmaps import CasesB0maps, CasesZmaps
14+
from helpers import to_interface
15+
from helpers.factories import _param_array_interface_np_cp, from_interface
2016

2117

22-
def calculate_true_offresonance_term(fieldmap, t, array_interface):
23-
"""Calculate non-approximate off-resonance modulation term."""
24-
fieldmap = to_interface(fieldmap, array_interface)
25-
t = to_interface(t, array_interface)
18+
class CasesB0maps:
19+
"""B0 field maps cases we want to test.
2620
27-
xp = get_array_module(fieldmap)
28-
arg = t * fieldmap[..., None]
29-
arg = arg[None, ...].swapaxes(0, -1)[..., 0]
30-
return xp.exp(-arg)
21+
Each case return a field map and the binary spatial support of the object.
22+
"""
3123

24+
def case_real2D(self, N=64, b0_range=(-300, 300)):
25+
"""Create a real (B0 only) 2D field map."""
26+
b0_map, mask = make_b0map(2 * [N])
27+
return b0_map, None, mask
3228

33-
def calculate_approx_offresonance_term(B, C):
34-
"""Calculate approximate off-resonance modulation term."""
35-
field_term = 0.0
36-
for n in range(B.shape[0]):
37-
tmp = B[n] * C[n][..., None]
38-
tmp = tmp[None, ...].swapaxes(0, -1)[..., 0]
39-
field_term += tmp
40-
return field_term
29+
# def case_real3D(self, N=32, b0range=(-300, 300)):
30+
# """Create a real (B0 only) 3D field map."""
31+
# b0_map, mask = make_b0map(3 * [N], b0range)
32+
# return b0_map, None, mask
4133

34+
def case_complex2D(self, N=64, b0range=(-300, 300), t2svalue=15.0):
35+
"""Create a complex (R2* + 1j * B0) 2D field map."""
36+
# Generate real and imaginary parts
37+
t2s_map, _ = make_t2smap(2 * [N], t2svalue)
38+
b0_map, mask = make_b0map(2 * [N], b0range)
4239

43-
@_param_array_interface
44-
@parametrize_with_cases("b0map, mask", cases=CasesB0maps)
45-
def test_b0map_coeff(b0map, mask, array_interface):
46-
"""Test exponential approximation for B0 field only."""
47-
if array_interface == "torch-gpu" and not CUPY_AVAILABLE:
48-
pytest.skip("GPU computations requires cupy")
40+
# Convert T2* map to R2* map
41+
t2s_map = t2s_map * 1e-3 # ms -> s
42+
r2s_map = 1.0 / (t2s_map + 1e-9) # Hz
43+
r2s_map = mask * r2s_map
4944

50-
# Generate readout times
51-
tread = np.linspace(0.0, 5e-3, 501, dtype=np.float32)
45+
return b0_map, r2s_map, mask
5246

53-
# Generate coefficients
54-
B, tl = mrinufft.get_interpolators_from_fieldmap(
55-
to_interface(b0map, array_interface), tread, mask=mask, n_time_segments=100
56-
)
47+
# def case_complex3D(self, N=32, b0range=(-300, 300), t2svalue=15.0):
48+
# """Create a complex (R2* + 1j * B0) 3D field map."""
49+
# # Generate real and imaginary parts
50+
# t2s_map, _ = make_t2smap(3 * [N], t2svalue)
51+
# b0_map, mask = make_b0map(3 * [N], b0range)
5752

58-
# Calculate spatial coefficients
59-
C = MRIFourierCorrected.get_spatial_coefficients(
60-
to_interface(2 * math.pi * 1j * b0map, array_interface), tl
61-
)
53+
# # Convert T2* map to R2* map
54+
# t2s_map = t2s_map * 1e-3 # ms -> s
55+
# r2s_map = 1.0 / (t2s_map + 1e-9) # Hz
56+
# r2s_map = mask * r2s_map
57+
# return b0_map, r2s_map, mask
6258

63-
# Assert properties
64-
assert B.shape == (100, 501)
65-
assert C.shape == (100, *b0map.shape)
66-
67-
# Correct approximation
68-
expected = calculate_true_offresonance_term(
69-
0 + 2 * math.pi * 1j * b0map, tread, array_interface
70-
)
71-
actual = calculate_approx_offresonance_term(B, C)
72-
assert_allclose(actual, expected, atol=1e-3, rtol=1e-3, interface=array_interface)
73-
74-
75-
@_param_array_interface
76-
@parametrize_with_cases("zmap, mask", cases=CasesZmaps)
77-
def test_zmap_coeff(zmap, mask, array_interface):
78-
"""Test exponential approximation for complex Z = R2* + 1j *B0 field."""
79-
if array_interface == "torch-gpu" and CUPY_AVAILABLE is False:
80-
pytest.skip("GPU computations requires cupy")
8159

60+
@_param_array_interface_np_cp
61+
@parametrize_with_cases("b0_map, r2s_map, mask", cases=CasesB0maps)
62+
@parametrize("method", ["svd-full", "mti", "mfi"])
63+
@parametrize("L", [40, -1])
64+
def test_b0map_coeff(b0_map, r2s_map, mask, method, L, array_interface):
65+
"""Test exponential approximation for B0 field only."""
8266
# Generate readout times
83-
tread = np.linspace(0.0, 5e-3, 501, dtype=np.float32)
84-
85-
# Generate coefficients
86-
B, tl = mrinufft.get_interpolators_from_fieldmap(
87-
to_interface(zmap.imag, array_interface),
88-
tread,
89-
mask=mask,
90-
r2star_map=to_interface(zmap.real, array_interface),
91-
n_time_segments=100,
92-
)
93-
94-
# Calculate spatial coefficients
95-
C = MRIFourierCorrected.get_spatial_coefficients(
96-
to_interface(2 * math.pi * zmap, array_interface), tl
67+
Nt = 400
68+
tread = np.linspace(0.0, 5e-3, Nt, dtype=np.float32)
69+
70+
cpx_fieldmap = get_complex_fieldmap_rad(b0_map, r2s_map).astype(np.complex64)
71+
72+
E_full = np.exp(np.outer(tread, cpx_fieldmap[mask]))
73+
74+
kwargs = {}
75+
if method == "svd-full": # Truncated SVD is flacky (esp. for cupy)
76+
kwargs["partial_svd"] = False
77+
method = "svd"
78+
B, C, _ = get_orc_factorization(method)(
79+
to_interface(cpx_fieldmap, array_interface),
80+
to_interface(tread, array_interface),
81+
to_interface(mask, array_interface),
82+
L=L,
83+
lazy=False,
84+
n_bins=4096,
85+
**kwargs,
9786
)
9887

88+
if L == -1:
89+
L = B.shape[1]
90+
print(L)
9991
# Assert properties
100-
assert B.shape == (100, 501)
101-
assert C.shape == (100, *zmap.shape)
92+
assert B.shape == (Nt, L)
93+
assert C.shape == (L, *b0_map.shape)
10294

103-
# Correct approximation
104-
expected = calculate_true_offresonance_term(
105-
2 * math.pi * zmap, tread, array_interface
106-
)
107-
actual = calculate_approx_offresonance_term(B, C)
108-
assert_allclose(actual, expected, atol=1e-3, rtol=1e-3, interface=array_interface)
95+
# Check that the approximation match the full matrix.
96+
B = from_interface(B, array_interface)
97+
C = from_interface(C, array_interface)
98+
E2 = B @ C[:, mask]
99+
# TODO get closer bound somehow ?
100+
npt.assert_allclose(E2, E_full, atol=5e-3, rtol=5e-3)
109101

110102

111103
def test_b0_map_upsampling_warns_and_matches_shape():
@@ -134,5 +126,5 @@ def test_b0_map_upsampling_warns_and_matches_shape():
134126
)
135127

136128
# check that no exception is raised and internal shape matches
137-
assert op.B.shape[1] == len(readout_time)
129+
assert op.B.shape[0] == len(readout_time)
138130
assert op.shape == shape_target

0 commit comments

Comments
 (0)