Skip to content

Commit 4add9ba

Browse files
committed
feat: add integration test for off-resonance correction
1 parent 3a61bc9 commit 4add9ba

File tree

3 files changed

+100
-1
lines changed

3 files changed

+100
-1
lines changed

.github/workflows/test-ci.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -145,7 +145,7 @@ jobs:
145145
${{ env.create_venv }}
146146
${{ env.activate_venv }}
147147
python -m pip install --upgrade pip wheel
148-
python -m pip install -e .[test]
148+
python -m pip install -e .[test,extra]
149149
python -m pip install cupy-cuda12x finufft "numpy<2.0"
150150
151151
- name: Install torch with CUDA 12.x

tests/helpers/asserts.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,9 @@ def assert_correlate(a, b, slope=1.0, slope_err=1e-3, r_value_err=1e-3):
5454
a.flatten(), b.flatten()
5555
)
5656
abs_slope_reg = abs(slope_reg)
57+
if np.iscomplex(rvalue):
58+
rvalue = abs(rvalue)
59+
5760
if r_value_err is not None and abs(rvalue - 1) > r_value_err:
5861
raise AssertionError(
5962
f"RValue {rvalue} != 1 ± {r_value_err}\n "

tests/operators/test_orc.py

Lines changed: 96 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,96 @@
1+
"""Test that the ORC NUFFT approximates the Conjugate Phase (CP) expression."""
2+
3+
from mrinufft.extras.field_map import get_complex_fieldmap_rad
4+
from mrinufft.operators.interfaces.nudft_numpy import get_fourier_matrix
5+
import numpy as np
6+
import numpy.testing as npt
7+
8+
from pytest_cases import parametrize_with_cases, parametrize, fixture
9+
10+
from helpers import image_from_op, kspace_from_op, assert_correlate
11+
12+
from case_trajectories import CasesTrajectories
13+
14+
from mrinufft import get_operator
15+
from mrinufft.extras.field_map import make_b0map, make_t2smap
16+
17+
18+
def get_extended_fourier_matrix(ktraj, shape, cpx_fieldmap, readout_time):
19+
"""Generate the extended fourier matrix with off-resonnance.
20+
21+
For test purposes only.
22+
"""
23+
base_fourier = get_fourier_matrix(ktraj, shape, normalize=True)
24+
off_grid = np.outer(readout_time, cpx_fieldmap.ravel())
25+
base_fourier *= np.exp(off_grid).astype(np.complex64)
26+
return base_fourier
27+
28+
29+
@fixture(scope="module")
30+
@parametrize_with_cases("kspace_locs, shape", cases=CasesTrajectories.case_random2D)
31+
@parametrize("backend", ["finufft", "cufinufft", "gpunufft"])
32+
def operator(kspace_locs, shape, backend):
33+
"""Create an operator with off resonance mapping support."""
34+
return get_operator(backend)(
35+
kspace_locs, shape, n_coils=1, n_batchs=1, density=False, squeeze_dims=True
36+
)
37+
38+
39+
@fixture(scope="module")
40+
def orc_info(operator):
41+
"""Augment the operator to use B0 setup."""
42+
b0_map, mask = make_b0map(operator.shape, b0range=(-300, 300))
43+
# t2s_map, _ = make_t2smap(operator.shape, t2svalue=15)
44+
# # Convert T2* map to R2* map
45+
# t2s_map = t2s_map * 1e-3 # ms -> s
46+
# r2s_map = 1.0 / (t2s_map + 1e-9) # Hz
47+
# r2s_map = mask * r2s_map
48+
r2s_map = None
49+
50+
cpx_fieldmap = get_complex_fieldmap_rad(b0_map, r2s_map)
51+
readout_time = np.linspace(0, 5e-2, len(operator.samples), dtype=np.float32)
52+
cp_matrix = get_extended_fourier_matrix(
53+
operator.samples, operator.shape, cpx_fieldmap, readout_time
54+
)
55+
56+
orc_nufft = operator.with_off_resonance_correction(
57+
b0_map=b0_map,
58+
r2star_map=r2s_map,
59+
mask=mask,
60+
readout_time=readout_time,
61+
)
62+
63+
return orc_nufft, cp_matrix
64+
65+
66+
@fixture(scope="module")
67+
def image_data(operator):
68+
"""Generate a random image. Remains constant for the module."""
69+
return image_from_op(operator)
70+
71+
72+
@fixture(scope="module")
73+
def kspace_data(operator):
74+
"""Generate a random kspace. Remains constant for the module."""
75+
return kspace_from_op(operator)
76+
77+
78+
def test_orc_forward(
79+
orc_info,
80+
image_data,
81+
):
82+
"""Test that the forward approximation works."""
83+
orc_nufft, ext_mat = orc_info
84+
ksp = orc_nufft.op(image_data)
85+
ksp_ideal = ext_mat @ image_data.ravel()
86+
87+
assert_correlate(ksp.squeeze(), ksp_ideal)
88+
89+
90+
def test_orc_adjoint(orc_info, kspace_data):
91+
"""Test taht the adjoint approximation works."""
92+
orc_nufft, ext_mat = orc_info
93+
img = orc_nufft.adj_op(kspace_data)
94+
img_ideal = ext_mat.conj().T @ kspace_data.ravel()
95+
img_ideal = img_ideal.reshape(orc_nufft.shape)
96+
assert_correlate(img.squeeze(), img_ideal)

0 commit comments

Comments
 (0)