|
| 1 | +"""Test that the ORC NUFFT approximates the Conjugate Phase (CP) expression.""" |
| 2 | + |
| 3 | +from mrinufft.extras.field_map import get_complex_fieldmap_rad |
| 4 | +from mrinufft.operators.interfaces.nudft_numpy import get_fourier_matrix |
| 5 | +import numpy as np |
| 6 | +import numpy.testing as npt |
| 7 | + |
| 8 | +from pytest_cases import parametrize_with_cases, parametrize, fixture |
| 9 | + |
| 10 | +from helpers import image_from_op, kspace_from_op, assert_correlate |
| 11 | + |
| 12 | +from case_trajectories import CasesTrajectories |
| 13 | + |
| 14 | +from mrinufft import get_operator |
| 15 | +from mrinufft.extras.field_map import make_b0map, make_t2smap |
| 16 | + |
| 17 | + |
| 18 | +def get_extended_fourier_matrix(ktraj, shape, cpx_fieldmap, readout_time): |
| 19 | + """Generate the extended fourier matrix with off-resonnance. |
| 20 | +
|
| 21 | + For test purposes only. |
| 22 | + """ |
| 23 | + base_fourier = get_fourier_matrix(ktraj, shape, normalize=True) |
| 24 | + off_grid = np.outer(readout_time, cpx_fieldmap.ravel()) |
| 25 | + base_fourier *= np.exp(off_grid).astype(np.complex64) |
| 26 | + return base_fourier |
| 27 | + |
| 28 | + |
| 29 | +@fixture(scope="module") |
| 30 | +@parametrize_with_cases("kspace_locs, shape", cases=CasesTrajectories.case_random2D) |
| 31 | +@parametrize("backend", ["finufft", "cufinufft", "gpunufft"]) |
| 32 | +def operator(kspace_locs, shape, backend): |
| 33 | + """Create an operator with off resonance mapping support.""" |
| 34 | + return get_operator(backend)( |
| 35 | + kspace_locs, shape, n_coils=1, n_batchs=1, density=False, squeeze_dims=True |
| 36 | + ) |
| 37 | + |
| 38 | + |
| 39 | +@fixture(scope="module") |
| 40 | +def orc_info(operator): |
| 41 | + """Augment the operator to use B0 setup.""" |
| 42 | + b0_map, mask = make_b0map(operator.shape, b0range=(-300, 300)) |
| 43 | + # t2s_map, _ = make_t2smap(operator.shape, t2svalue=15) |
| 44 | + # # Convert T2* map to R2* map |
| 45 | + # t2s_map = t2s_map * 1e-3 # ms -> s |
| 46 | + # r2s_map = 1.0 / (t2s_map + 1e-9) # Hz |
| 47 | + # r2s_map = mask * r2s_map |
| 48 | + r2s_map = None |
| 49 | + |
| 50 | + cpx_fieldmap = get_complex_fieldmap_rad(b0_map, r2s_map) |
| 51 | + readout_time = np.linspace(0, 5e-2, len(operator.samples), dtype=np.float32) |
| 52 | + cp_matrix = get_extended_fourier_matrix( |
| 53 | + operator.samples, operator.shape, cpx_fieldmap, readout_time |
| 54 | + ) |
| 55 | + |
| 56 | + orc_nufft = operator.with_off_resonance_correction( |
| 57 | + b0_map=b0_map, |
| 58 | + r2star_map=r2s_map, |
| 59 | + mask=mask, |
| 60 | + readout_time=readout_time, |
| 61 | + ) |
| 62 | + |
| 63 | + return orc_nufft, cp_matrix |
| 64 | + |
| 65 | + |
| 66 | +@fixture(scope="module") |
| 67 | +def image_data(operator): |
| 68 | + """Generate a random image. Remains constant for the module.""" |
| 69 | + return image_from_op(operator) |
| 70 | + |
| 71 | + |
| 72 | +@fixture(scope="module") |
| 73 | +def kspace_data(operator): |
| 74 | + """Generate a random kspace. Remains constant for the module.""" |
| 75 | + return kspace_from_op(operator) |
| 76 | + |
| 77 | + |
| 78 | +def test_orc_forward( |
| 79 | + orc_info, |
| 80 | + image_data, |
| 81 | +): |
| 82 | + """Test that the forward approximation works.""" |
| 83 | + orc_nufft, ext_mat = orc_info |
| 84 | + ksp = orc_nufft.op(image_data) |
| 85 | + ksp_ideal = ext_mat @ image_data.ravel() |
| 86 | + |
| 87 | + assert_correlate(ksp.squeeze(), ksp_ideal) |
| 88 | + |
| 89 | + |
| 90 | +def test_orc_adjoint(orc_info, kspace_data): |
| 91 | + """Test taht the adjoint approximation works.""" |
| 92 | + orc_nufft, ext_mat = orc_info |
| 93 | + img = orc_nufft.adj_op(kspace_data) |
| 94 | + img_ideal = ext_mat.conj().T @ kspace_data.ravel() |
| 95 | + img_ideal = img_ideal.reshape(orc_nufft.shape) |
| 96 | + assert_correlate(img.squeeze(), img_ideal) |
0 commit comments