Skip to content

Commit e6cdf85

Browse files
authored
Merge branch 'master' into pulseq
2 parents 1b9dc6f + 2489ef2 commit e6cdf85

File tree

20 files changed

+669
-399
lines changed

20 files changed

+669
-399
lines changed

examples/GPU/example_3d_trajectory_display.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,14 +14,22 @@
1414
import os
1515
from mrinufft.trajectories.display3D import get_gridded_trajectory
1616
import mrinufft.trajectories.trajectory3D as mtt
17-
from mrinufft.trajectories.utils import Gammas
17+
from mrinufft.trajectories.utils import Gammas, Acquisition
1818
import matplotlib.pyplot as plt
1919
import numpy as np
2020

2121

2222
BACKEND = os.environ.get("MRINUFFT_BACKEND", "gpunufft")
2323

2424

25+
# %%
26+
# Acquisition parameters
27+
# ======================
28+
# Here we use acquisition defaults for the trajectory gridding.
29+
30+
acq = Acquisition.default
31+
32+
2533
# %%
2634
# Helper function to Displaying 3D Gridded Trajectories
2735
# =====================================================
@@ -57,9 +65,8 @@ def create_grid(grid_type, trajectories, traj_params, **kwargs):
5765
for i, (name, traj) in enumerate(trajectories.items()):
5866
grid = get_gridded_trajectory(
5967
traj,
60-
traj_params["img_size"],
68+
acq,
6169
grid_type=grid_type,
62-
traj_params=traj_params,
6370
backend=BACKEND,
6471
osf=2,
6572
**kwargs,

examples/operators/example_offresonance.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -85,10 +85,10 @@
8585

8686
from mrinufft import initialize_2D_spiral
8787
from mrinufft.density import voronoi
88-
from mrinufft.trajectories.utils import DEFAULT_RASTER_TIME
88+
from mrinufft.trajectories.utils import Acquisition
8989

9090
samples = initialize_2D_spiral(Nc=48, Ns=600, nb_revolutions=10)
91-
t_read = np.arange(samples.shape[1]) * DEFAULT_RASTER_TIME * 1e-3
91+
t_read = np.arange(samples.shape[1]) * Acquisition.default.raster_time
9292
t_read = np.repeat(t_read[None, ...], samples.shape[0], axis=0)
9393
density = voronoi(samples)
9494

examples/trajectories/example_3D_trajectories.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,6 @@
3030

3131
# Internal
3232
import mrinufft as mn
33-
from mrinufft import display_2D_trajectory, display_3D_trajectory
3433

3534
# %%
3635
# Script options

examples/trajectories/example_display_config.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515

1616
from mrinufft import display_2D_trajectory, display_3D_trajectory, displayConfig
1717
from mrinufft.trajectories import conify, initialize_2D_spiral
18+
from mrinufft.trajectories.utils import Acquisition, Hardware
1819

1920
# %%
2021
# Script options
@@ -118,3 +119,33 @@ def show_traj(traj, name, values, **kwargs):
118119
["tab:blue", "tab:orange", "tab:red"],
119120
show_constraints=True,
120121
)
122+
123+
# You can also change the values of gmax and smax in order to see how the constraint
124+
# violations change.
125+
#
126+
acqs = [
127+
Acquisition(
128+
fov=(0.256, 0.256, 0.256),
129+
img_size=(256, 256, 256),
130+
hardware=Hardware(gmax=0.04, smax=50),
131+
), # limiting slew rate to 50 T/m/s
132+
Acquisition(
133+
fov=(0.256, 0.256, 0.256),
134+
img_size=(256, 256, 256),
135+
hardware=Hardware(gmax=0.04, smax=100),
136+
), # limiting slew rate to 100 T/m/s
137+
Acquisition(
138+
fov=(0.256, 0.256, 0.256),
139+
img_size=(256, 256, 256),
140+
hardware=Hardware(gmax=0.04, smax=200),
141+
),
142+
] # limiting slew rate to 200 T/m/s
143+
144+
# you can use Acquisition as a Context Manager, like display config.
145+
# Or pass it to the display function as well.
146+
for acq in acqs:
147+
with acq:
148+
display_3D_trajectory(traj, show_constraints=True)
149+
150+
# equivalent to
151+
# display_3D_trajectory(traj, show_constraints=True, acq=acq)

examples/trajectories/utils.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,15 +16,15 @@
1616
def show_trajectory(trajectory, one_shot, figure_size):
1717
if trajectory.shape[-1] == 2:
1818
ax = display_2D_trajectory(
19-
trajectory, size=figure_size, one_shot=one_shot % trajectory.shape[0]
19+
trajectory, figsize=figure_size, one_shot=one_shot % trajectory.shape[0]
2020
)
2121
ax.set_aspect("equal")
2222
plt.tight_layout()
2323
plt.show()
2424
else:
2525
ax = display_3D_trajectory(
2626
trajectory,
27-
size=figure_size,
27+
figsize=figure_size,
2828
one_shot=one_shot % trajectory.shape[0],
2929
per_plane=False,
3030
)
@@ -49,15 +49,15 @@ def show_trajectories(
4949
if dim == "3D" and traj.shape[-1] == 3:
5050
ax = display_3D_trajectory(
5151
traj,
52-
size=subfig_size,
52+
figsize=subfig_size,
5353
one_shot=one_shot % traj.shape[0],
5454
subfigure=subfig,
5555
per_plane=False,
5656
)
5757
else:
5858
ax = display_2D_trajectory(
5959
traj[..., axes],
60-
size=subfig_size,
60+
figsize=subfig_size,
6161
one_shot=one_shot % traj.shape[0],
6262
subfigure=subfig,
6363
)

pyproject.toml

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,21 +4,21 @@ description = "MRI Non-Cartesian Fourier Operators with multiple computation bac
44
authors = [{name="Pierre-antoine Comby", email="[email protected]"}]
55

66
readme = "README.md"
7-
dependencies = ["numpy", "scipy", "matplotlib", "tqdm", "joblib"]
8-
requires-python = ">=3.9"
7+
dependencies = ["numpy>=2.2", "scipy>=1.13", "matplotlib>=3.9", "tqdm", "joblib>=1.4"]
8+
requires-python = ">=3.10"
99

1010
dynamic = ["version"]
1111

1212
[project.optional-dependencies]
1313

14-
gpunufft = ["gpuNUFFT>=0.9.0", "cupy-cuda12x"]
14+
gpunufft = ["gpuNUFFT>=0.10.1", "cupy-cuda12x"]
1515

1616
torchkbnufft = ["torchkbnufft", "cupy-cuda12x"]
1717
torchkbnufft-cpu = ["torchkbnufft", "cupy-cuda12x"]
1818
torchkbnufft-gpu = ["torchkbnufft", "cupy-cuda12x"]
1919

2020
cufinufft = ["cufinufft>=2.4.0", "cupy-cuda12x"]
21-
tensorflow = ["tensorflow-mri==0.21.0", "tensorflow-probability==0.17.0", "tensorflow-io==0.27.0", "matplotlib==3.7"]
21+
tensorflow = ["tensorflow-mri>=0.22.0"]
2222
finufft = ["finufft>=2.4.0"]
2323
sigpy = ["sigpy"]
2424
pynfft = ["pynfft3"]

src/mrinufft/extras/smaps.py

Lines changed: 76 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -4,17 +4,21 @@
44

55
from mrinufft.density.utils import flat_traj
66
from mrinufft._utils import get_array_module
7+
from mrinufft._array_compat import with_numpy_cupy
78
from .utils import register_smaps
89
import numpy as np
10+
from numpy.typing import NDArray
11+
12+
from collections.abc import Callable
913

1014

1115
def _extract_kspace_center(
12-
kspace_data,
13-
kspace_loc,
14-
threshold=None,
15-
density=None,
16-
window_fun="ellipse",
17-
):
16+
kspace_data: NDArray,
17+
kspace_loc: NDArray,
18+
threshold: float | tuple[float, ...] = None,
19+
density: NDArray | None = None,
20+
window_fun: str | Callable[[NDArray], NDArray] = "ellipse",
21+
) -> tuple[NDArray, NDArray, NDArray | None]:
1822
r"""Extract k-space center and corresponding sampling locations.
1923
2024
The extracted center of the k-space, i.e. both the kspace locations and
@@ -81,7 +85,7 @@ def _extract_kspace_center(
8185
return data_thresholded, center_locations, dc
8286
else:
8387
if callable(window_fun):
84-
window = window_fun(center_locations)
88+
window = window_fun(kspace_loc)
8589
else:
8690
if window_fun in ["hann", "hanning", "hamming"]:
8791
radius = xp.linalg.norm(kspace_loc, axis=1)
@@ -99,16 +103,16 @@ def _extract_kspace_center(
99103
@register_smaps
100104
@flat_traj
101105
def low_frequency(
102-
traj,
103-
shape,
104-
kspace_data,
105-
backend,
106+
traj: NDArray,
107+
shape: tuple[int, ...],
108+
kspace_data: NDArray,
109+
backend: str,
106110
threshold: float | tuple[float, ...] = 0.1,
107-
density=None,
108-
window_fun: str = "ellipse",
111+
density: NDArray | None = None,
112+
window_fun: str | Callable[[NDArray], NDArray] = "ellipse",
109113
blurr_factor: int | float | tuple[float, ...] = 0.0,
110114
mask: bool = False,
111-
):
115+
) -> tuple[NDArray, NDArray]:
112116
"""
113117
Calculate low-frequency sensitivity maps.
114118
@@ -190,3 +194,61 @@ def low_frequency(
190194
SOS = np.linalg.norm(Smaps, axis=0) + 1e-10
191195
Smaps = Smaps / SOS
192196
return Smaps, SOS
197+
198+
199+
@with_numpy_cupy
200+
def coil_compression(
201+
kspace_data: NDArray,
202+
K: int | float,
203+
traj: NDArray | None = None,
204+
krad_thresh: float | None = None,
205+
) -> NDArray:
206+
"""
207+
Coil compression using principal component analysis on k-space data.
208+
209+
Parameters
210+
----------
211+
kspace_data : NDArray
212+
Multi-coil k-space data. Shape: (n_coils, n_samples).
213+
K : int or float
214+
Number of virtual coils to retain (if int), or energy threshold (if
215+
float between 0 and 1).
216+
traj : NDArray, optional
217+
Sampling trajectory. Shape: (n_samples, n_dims).
218+
krad_thresh : float, optional
219+
Relative k-space radius (as a fraction of maximum) to use for selecting
220+
the calibration region for principal component analysis. If None, use
221+
all k-space samples.
222+
223+
Returns
224+
-------
225+
NDArray
226+
Coil-compressed data. Shape: (K, n_samples) if K is int, number of
227+
retained components otherwise.
228+
"""
229+
xp = get_array_module(kspace_data)
230+
231+
if krad_thresh is not None and traj is not None:
232+
traj_rad = xp.sqrt(xp.sum(traj**2, axis=-1))
233+
center_data = kspace_data[:, traj_rad < krad_thresh * xp.max(traj)]
234+
elif krad_thresh is None:
235+
center_data = kspace_data
236+
else:
237+
raise ValueError("traj and krad_thresh must be specified.")
238+
239+
# Compute the covar matrix of selected data
240+
cov = center_data @ center_data.T.conj()
241+
w, v = xp.linalg.eigh(cov)
242+
# sort eigenvalues largest to smallest
243+
si = xp.argsort(w)[::-1]
244+
w_sorted = w[si]
245+
v_sorted = v[si]
246+
if isinstance(K, float):
247+
# retain enough components to reach energy K
248+
w_cumsum = xp.cumsum(w_sorted) # from largest to smallest
249+
total_energy = xp.sum(w_sorted)
250+
K = int(xp.searchsorted(w_cumsum / total_energy, K, side="left") + 1)
251+
K = min(K, w_sorted.size)
252+
V = v_sorted[:K] # use top K component
253+
compress_data = V @ kspace_data
254+
return compress_data

0 commit comments

Comments
 (0)