Skip to content

Commit 79dcf6c

Browse files
authored
Merge branch 'master' into add_prephasors
2 parents 20b4b1c + fd6f970 commit 79dcf6c

File tree

4 files changed

+82
-16
lines changed

4 files changed

+82
-16
lines changed

src/mrinufft/io/nsp.py

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -368,6 +368,7 @@ def read_trajectory(
368368
raster_time: float = DEFAULT_RASTER_TIME,
369369
read_shots: bool = False,
370370
normalize_factor: float = KMAX,
371+
pre_skip: int = 0,
371372
):
372373
"""Get k-space locations from gradient file.
373374
@@ -389,9 +390,13 @@ def read_trajectory(
389390
read_shots : bool, optional
390391
Whether in read shots configuration which accepts an extra
391392
point at end, by default False
392-
normalize : float, optional
393+
normalize_factor : float, optional
393394
Whether to normalize the k-space locations, by default 0.5
394395
When None, normalization is not done.
396+
pre_skip: int, optional
397+
Number of samples to skip from the start of each shot,
398+
by default 0. This is useful when we want to avoid artifacts
399+
from ADC switching in UTE sequences.
395400
396401
Returns
397402
-------
@@ -522,6 +527,15 @@ def read_trajectory(
522527
if normalize_factor is not None:
523528
Kmax = img_size / 2 / fov
524529
kspace_loc = kspace_loc / Kmax * normalize_factor
530+
if pre_skip > 0:
531+
if pre_skip >= num_samples_per_shot:
532+
raise ValueError(
533+
"skip_first_Nsamples should be less than num_adc_samples"
534+
)
535+
oversample_factor = num_adc_samples / num_samples_per_shot
536+
skip_samples = pre_skip * int(oversample_factor)
537+
kspace_loc = kspace_loc[:, skip_samples:]
538+
params["num_adc_samples"] = num_adc_samples - skip_samples
525539
return kspace_loc, params
526540

527541

@@ -532,6 +546,7 @@ def read_arbgrad_rawdat(
532546
squeeze: bool = True,
533547
slice_num: int | None = None,
534548
contrast_num: int | None = None,
549+
pre_skip: int = 0,
535550
data_type: str = "ARBGRAD_VE11C",
536551
): # pragma: no cover
537552
"""Read raw data from a Siemens MRI file.
@@ -550,6 +565,10 @@ def read_arbgrad_rawdat(
550565
The slice to read, by default None. This applies for 2D data.
551566
contrast_num: int, optional
552567
The contrast to read, by default None.
568+
pre_skip : int, optional
569+
Number of samples to skip from the start of each shot,
570+
by default 0. This is useful when we want to avoid artifacts
571+
from ADC switching in UTE sequences.
553572
data_type : str, optional
554573
The type of data to read, by default 'ARBGRAD_VE11C'.
555574
@@ -598,4 +617,12 @@ def read_arbgrad_rawdat(
598617
"Phoenix", ("sFastImaging", "lTurboFactor")
599618
)[0]
600619
hdr["type"] = "ARBGRAD_MP2RAGE"
620+
if pre_skip > 0:
621+
samples_to_skip = int(hdr["oversampling_factor"] * pre_skip)
622+
if samples_to_skip >= hdr["n_adc_samples"]:
623+
raise ValueError(
624+
"Samples to skip should be less than n_samples in the data"
625+
)
626+
data = data[:, :, samples_to_skip:]
627+
hdr["n_adc_samples"] -= samples_to_skip
601628
return data, hdr

src/mrinufft/operators/interfaces/cufinufft.py

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ def __init__(
6464
self._kz = cp.array(samples[:, 2], copy=False) if self.ndim == 3 else None
6565
for i in [1, 2]:
6666
self._make_plan(i, **kwargs)
67-
self._set_pts(i, samples)
67+
self._set_pts(i)
6868

6969
@property
7070
def dtype(self):
@@ -84,13 +84,15 @@ def _make_plan(self, typ, **kwargs):
8484
**kwargs,
8585
)
8686

87-
def _set_pts(self, typ, samples):
87+
def _set_kxyz(self, samples):
88+
self._kx.set(samples[:, 0])
89+
self._ky.set(samples[:, 1])
90+
if self.ndim == 3:
91+
self._kz.set(samples[:, 2])
92+
93+
def _set_pts(self, typ):
8894
plan = self.grad_plan if typ == "grad" else self.plans[typ]
89-
plan.setpts(
90-
cp.array(samples[:, 0], copy=False),
91-
cp.array(samples[:, 1], copy=False),
92-
cp.array(samples[:, 2], copy=False) if self.ndim == 3 else None,
93-
)
95+
plan.setpts(self._kx, self._ky, self._kz)
9496

9597
def _destroy_plan(self, typ):
9698
if self.plans[typ] is not None:
@@ -272,10 +274,11 @@ def samples(self, new_samples):
272274
np.float32, copy=False
273275
)
274276
)
277+
self.raw_op._set_kxyz(self._samples)
275278
for typ in [1, 2, "grad"]:
276279
if typ == "grad" and not self._grad_wrt_traj:
277280
continue
278-
self.raw_op._set_pts(typ, samples=self._samples)
281+
self.raw_op._set_pts(typ)
279282
self.compute_density(self._density_method)
280283

281284
@FourierOperatorBase.density.setter
@@ -808,7 +811,7 @@ def _make_plan_grad(self, **kwargs):
808811
isign=1,
809812
**kwargs,
810813
)
811-
self.raw_op._set_pts(typ="grad", samples=self.samples)
814+
self.raw_op._set_pts(typ="grad")
812815

813816
def get_lipschitz_cst(self, max_iter=10, **kwargs):
814817
"""Return the Lipschitz constant of the operator.

src/mrinufft/operators/off_resonance.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -198,6 +198,9 @@ class MRIFourierCorrected(FourierOperatorBase):
198198
Must have same shape as ``b0_map``.
199199
The default is ``None`` (purely imaginary field).
200200
Also supports Cupy arrays and Torch tensors.
201+
negate: bool, optional, default=False
202+
If True, negate the field map. Useful for matching the convention of
203+
your field map generation.
201204
backend: str, optional
202205
The backend to use for computations. Either 'cpu', 'gpu' or 'torch'.
203206
The default is `cpu`.
@@ -221,6 +224,7 @@ def __init__(
221224
r2star_map=None,
222225
B=None,
223226
tl=None,
227+
negate=False,
224228
backend="cpu",
225229
):
226230
if backend == "gpu" and not CUPY_AVAILABLE:
@@ -235,7 +239,9 @@ def __init__(
235239
raise ValueError("Unsupported backend.")
236240

237241
self._fourier_op = fourier_op
238-
242+
if not isinstance(negate, bool):
243+
raise ValueError("negate must be a boolean value.")
244+
self.isign = -1 if negate else 1
239245
self.n_coils = fourier_op.n_coils
240246
self.shape = fourier_op.shape
241247
self.smaps = fourier_op.smaps
@@ -275,7 +281,7 @@ def __init__(
275281
self.C = None
276282
self.field_map = field_map
277283
else:
278-
self.C = _get_spatial_coefficients(field_map, self.tl)
284+
self.C = _get_spatial_coefficients(field_map, self.tl, isign=self.isign)
279285
self.field_map = None
280286

281287
def op(self, data, *args):
@@ -377,11 +383,10 @@ def _get_complex_fieldmap(b0_map, r2star_map=None):
377383
return field_map
378384

379385

380-
def _get_spatial_coefficients(field_map, tl):
386+
def _get_spatial_coefficients(field_map, tl, isign=-1):
381387
xp = get_array_module(field_map)
382-
383388
# get spatial coeffs
384-
C = xp.exp(-tl * field_map[..., None])
389+
C = xp.exp(isign * tl * field_map[..., None])
385390
C = C[None, ...].swapaxes(0, -1)[
386391
..., 0
387392
] # (..., n_time_segments) -> (n_time_segments, ...)

tests/test_offres_exp_approx.py

Lines changed: 32 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from mrinufft._array_compat import CUPY_AVAILABLE
1313
from mrinufft._utils import get_array_module
1414
from mrinufft.operators.off_resonance import MRIFourierCorrected
15-
15+
from mrinufft import get_operator
1616

1717
from helpers import to_interface, assert_allclose
1818
from helpers.factories import _param_array_interface
@@ -106,3 +106,34 @@ def test_zmap_coeff(zmap, mask, array_interface):
106106
)
107107
actual = calculate_approx_offresonance_term(B, C)
108108
assert_allclose(actual, expected, atol=1e-3, rtol=1e-3, interface=array_interface)
109+
110+
111+
def test_b0_map_upsampling_warns_and_matches_shape():
112+
"""Test that MRIFourierCorrected upscales the b0_map and warns if shape mismatch exists."""
113+
114+
shape_target = (16, 16, 16)
115+
b0_shape = (8, 8, 8)
116+
117+
b0_map = np.ones(b0_shape, dtype=np.float32)
118+
kspace = np.zeros((10, 3), dtype=np.float32)
119+
smaps = np.ones((1, *shape_target), dtype=np.complex64)
120+
readout_time = np.ones(10, dtype=np.float32)
121+
122+
nufft = get_operator("finufft")(
123+
samples=kspace,
124+
shape=shape_target,
125+
n_coils=1,
126+
smaps=smaps,
127+
density=False,
128+
)
129+
130+
with pytest.warns(UserWarning):
131+
op = MRIFourierCorrected(
132+
nufft,
133+
b0_map=b0_map,
134+
readout_time=readout_time,
135+
)
136+
137+
# check that no exception is raised and internal shape matches
138+
assert op.B.shape[1] == len(readout_time)
139+
assert op.shape == shape_target

0 commit comments

Comments
 (0)