Skip to content

Commit 16be647

Browse files
committed
fixup! Add arc length parameterization
1 parent d8902c6 commit 16be647

File tree

7 files changed

+222
-60
lines changed

7 files changed

+222
-60
lines changed
Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,101 @@
1+
"""
2+
======================
3+
Trajectory constraints
4+
======================
5+
6+
A collection of methods to make trajectories fit hardware constraints.
7+
8+
"""
9+
10+
# %%
11+
# Hereafter we illustrate different methods to reduce the gradient
12+
# strengths and slew rates required for the trajectory to match the
13+
# hardware constraints of MRI machines. A summary table is available
14+
# below.
15+
#
16+
17+
# %%
18+
# .. list-table:: Constraint fitting methods
19+
# :header-rows: 1
20+
#
21+
# * -
22+
# - Gradient strength
23+
# - Slew rate
24+
# - Path preserved
25+
# - Density preserved
26+
# * - Arc-length parameterization
27+
# - Yes
28+
# - No
29+
# - Yes
30+
# - No
31+
#
32+
33+
# Internal
34+
import mrinufft as mn
35+
from mrinufft.trajectories.utils import compute_gradients_and_slew_rates
36+
from utils import show_trajectory_full
37+
38+
# External
39+
import numpy as np
40+
41+
42+
# %%
43+
# Script options
44+
# ==============
45+
# These options are used in the examples below as default values for all trajectories.
46+
47+
# Acquisition parameters
48+
resolution = 1e-3 # Resolution in meters
49+
raster_time = 40e-3 # Raster time in milliseconds
50+
51+
# %%
52+
53+
# Trajectory parameters
54+
Nc = 16 # Number of shots
55+
Ns = 3000 # Number of samples per shot
56+
in_out = False # Choose between in-out or center-out trajectories
57+
nb_zigzags = 5 # Number of zigzags for base trajectories
58+
59+
# %%
60+
61+
# Display parameters
62+
figure_size = 10 # Figure size for trajectory plots
63+
subfigure_size = 6 # Figure size for subplots
64+
one_shot = 2 * Nc // 3 # Highlight one shot in particular
65+
sample_freq = 60 # Frequency of samples to display in the trajectory plots
66+
67+
# %%
68+
# We will be using a cone trajectory to showcase the different methods as
69+
# it switches several times between high gradients and slew rates.
70+
71+
original_trajectory = mn.initialize_2D_cones(Nc, Ns, in_out=in_out, nb_zigzags=nb_zigzags)
72+
73+
# %%
74+
# Arc-length parameterization
75+
# ===========================
76+
# Arc-length parameterization is the simplest method to reduce the gradient
77+
# strength as it resamples the trajectory to have a constant distance between
78+
# samples. This is technically the lowest gradient strength achievable while
79+
# preserving the path of the trajectory, but it does not preserve the k-space
80+
# density and can lead to high slew rates as shown below.
81+
82+
show_trajectory_full(original_trajectory, one_shot, subfigure_size, sample_freq)
83+
84+
grads, slews = compute_gradients_and_slew_rates(original_trajectory, resolution, raster_time)
85+
grad_max, slew_max = np.max(grads), np.max(slews)
86+
print(f"Max gradient: {grad_max:.3f} T/m, Max slew rate: {slew_max:.3f} T/m/ms")
87+
88+
# %%
89+
#
90+
91+
from mrinufft.trajectories.projection import parameterize_by_arc_length
92+
93+
projected_trajectory = parameterize_by_arc_length(original_trajectory)
94+
95+
# %%
96+
97+
show_trajectory_full(projected_trajectory, one_shot, subfigure_size, sample_freq)
98+
99+
grads, slews = compute_gradients_and_slew_rates(projected_trajectory, resolution, raster_time)
100+
grad_max, slew_max = np.max(grads), np.max(slews)
101+
print(f"Max gradient: {grad_max:.3f} T/m, Max slew rate: {slew_max:.3f} T/m/ms")

examples/utils.py

Lines changed: 59 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,12 @@
99
import matplotlib.pyplot as plt
1010

1111
# Internal imports
12-
from mrinufft import display_2D_trajectory, display_3D_trajectory, displayConfig
12+
from mrinufft import (
13+
display_2D_trajectory,
14+
display_3D_trajectory,
15+
displayConfig,
16+
display_gradients_simply,
17+
)
1318
from mrinufft.trajectories.utils import KMAX
1419

1520

@@ -33,6 +38,59 @@ def show_trajectory(trajectory, one_shot, figure_size):
3338
plt.show()
3439

3540

41+
def show_trajectory_full(trajectory, one_shot, figure_size, sample_freq=10):
42+
# General configuration
43+
fig = plt.figure(figsize=(3 * figure_size, figure_size))
44+
subfigs = fig.subfigures(1, 3, wspace=0)
45+
46+
# Trajectory display
47+
subfigs[0].suptitle("Trajectory", fontsize=displayConfig.fontsize, x=0.5, y=0.98)
48+
if trajectory.shape[-1] == 2:
49+
ax = display_2D_trajectory(
50+
trajectory,
51+
size=figure_size,
52+
one_shot=one_shot,
53+
subfigure=subfigs[0],
54+
)
55+
else:
56+
ax = display_3D_trajectory(
57+
trajectory,
58+
size=figure_size,
59+
one_shot=one_shot,
60+
per_plane=False,
61+
subfigure=subfigs[0],
62+
)
63+
ax.set_aspect("equal")
64+
for i in range(trajectory.shape[0]):
65+
ax.scatter(trajectory[i, ::sample_freq, 0], trajectory[i, ::sample_freq, 1], s=15)
66+
67+
# Gradient display
68+
subfigs[1].suptitle("Gradients", fontsize=displayConfig.fontsize, x=0.5, y=0.98)
69+
display_gradients_simply(
70+
trajectory,
71+
shot_ids=[one_shot],
72+
figsize=figure_size,
73+
subfigure=subfigs[1],
74+
uni_gradient="k",
75+
uni_signal="gray",
76+
)
77+
78+
# Slew rates display
79+
subfigs[2].suptitle("Slew rates", fontsize=displayConfig.fontsize, x=0.5, y=0.98)
80+
display_gradients_simply(
81+
np.diff(trajectory, axis=1),
82+
shot_ids=[one_shot],
83+
figsize=figure_size,
84+
subfigure=subfigs[2],
85+
uni_gradient="k",
86+
uni_signal="gray",
87+
)
88+
subfigs[2].axes[0].set_ylabel("Sx")
89+
subfigs[2].axes[1].set_ylabel("Sy")
90+
subfigs[2].axes[2].set_ylabel("|S|")
91+
plt.show()
92+
93+
3694
def show_trajectories(
3795
function, arguments, one_shot, subfig_size, dim="3D", axes=(0, 1)
3896
):

src/mrinufft/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,8 @@
6464
displayConfig,
6565
display_2D_trajectory,
6666
display_3D_trajectory,
67+
display_gradients,
68+
display_gradients_simply,
6769
)
6870

6971
from .density import voronoi, cell_count, pipe, get_density
@@ -130,6 +132,8 @@
130132
"displayConfig",
131133
"display_2D_trajectory",
132134
"display_3D_trajectory",
135+
"display_gradients",
136+
"display_gradients_simply",
133137
]
134138

135139
from importlib.metadata import version, PackageNotFoundError

src/mrinufft/trajectories/__init__.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,20 @@
11
"""Collection of trajectories and tools used for non-Cartesian MRI."""
22

3-
from .display import display_2D_trajectory, display_3D_trajectory, displayConfig
3+
from .display import (
4+
display_2D_trajectory,
5+
display_3D_trajectory,
6+
displayConfig,
7+
display_gradients,
8+
display_gradients_simply,
9+
)
410
from .gradients import patch_center_anomaly
511
from .inits import (
612
initialize_2D_random_walk,
713
initialize_2D_travelling_salesman,
814
initialize_3D_random_walk,
915
initialize_3D_travelling_salesman,
1016
)
11-
from .projection import fit_arc_length
17+
from .projection import parameterize_by_arc_length
1218
from .sampling import (
1319
create_chauffert_density,
1420
create_cutoff_decay_density,
@@ -117,6 +123,8 @@
117123
"displayConfig",
118124
"display_2D_trajectory",
119125
"display_3D_trajectory",
126+
"display_gradients",
127+
"display_gradients_simply",
120128
# projection
121-
"fit_arc_length",
129+
"parameterize_by_arc_length",
122130
]

src/mrinufft/trajectories/display.py

Lines changed: 37 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -427,7 +427,6 @@ def display_gradients_simply(
427427
shot_ids: tuple[int, ...] = (0,),
428428
figsize: float = 5,
429429
fill_area: bool = True,
430-
show_signal: bool = True,
431430
uni_signal: str | None = "gray",
432431
uni_gradient: str | None = None,
433432
subfigure: plt.Figure | None = None,
@@ -447,10 +446,6 @@ def display_gradients_simply(
447446
Fills the area under the curve for improved visibility and
448447
representation of the integral, aka trajectory.
449448
The default is `True`.
450-
show_signal : bool, optional
451-
Show an additional illustration of the signal as
452-
the modulated distance to the center.
453-
The default is `True`.
454449
uni_signal : str or None, optional
455450
Define whether the signal should be represented by a
456451
unique color given as argument or just by the default
@@ -471,13 +466,13 @@ def display_gradients_simply(
471466
Axes of the figure.
472467
"""
473468
# Setup figure and labels
474-
Nd = trajectory.shape[-1]
469+
nb_axes = trajectory.shape[-1] + 1
475470
if subfigure is None:
476-
fig = plt.figure(figsize=(figsize, figsize * (Nd + show_signal) / Nd))
471+
fig = plt.figure(figsize=(figsize, figsize * nb_axes / (nb_axes - 1)))
477472
else:
478473
fig = subfigure
479-
axes = fig.subplots(Nd + show_signal, 1)
480-
for i, ax in enumerate(axes[:Nd]):
474+
axes = fig.subplots(nb_axes, 1)
475+
for i, ax in enumerate(axes[:nb_axes - 1]):
481476
ax.set_ylabel("G{}".format(["x", "y", "z"][i]), fontsize=displayConfig.fontsize)
482477
axes[-1].set_xlabel("Time", fontsize=displayConfig.fontsize)
483478

@@ -489,50 +484,42 @@ def display_gradients_simply(
489484

490485
# Plot the curves for each axis
491486
gradients = np.diff(trajectory, axis=1)
492-
vmax = 1.1 * np.max(np.abs(gradients[shot_ids, ...]))
493-
x_axis = np.arange(gradients.shape[1])
487+
vmax = 1.1 * np.max(np.linalg.norm(gradients[shot_ids, ...], axis=-1, ord=1))
488+
for ax in axes[:-1]:
489+
ax.set_ylim((-vmax, vmax))
490+
axes[-1].set_ylim(-0.1 * vmax, vmax)
491+
492+
time_axis = np.arange(gradients.shape[1])
494493
colors = displayConfig.get_colorlist()
495494
for j, s_id in enumerate(shot_ids):
496-
for i, ax in enumerate(axes[:Nd]):
497-
ax.set_ylim((-vmax, vmax))
498-
color = (
499-
uni_gradient
500-
if uni_gradient is not None
501-
else colors[j % displayConfig.nb_colors]
502-
)
503-
ax.plot(x_axis, gradients[s_id, ..., i], color=color)
495+
color = (
496+
uni_gradient
497+
if uni_gradient is not None
498+
else colors[j % displayConfig.nb_colors]
499+
)
500+
501+
# Set each axis individually
502+
for i, ax in enumerate(axes[:-1]):
503+
ax.plot(time_axis, gradients[s_id, ..., i], color=color)
504504
if fill_area:
505505
ax.fill_between(
506-
x_axis,
506+
time_axis,
507507
gradients[s_id, ..., i],
508508
alpha=displayConfig.alpha,
509509
color=color,
510510
)
511511

512-
# Return axes alone
513-
if not show_signal:
514-
return axes
515-
516-
# Show signal as modulated distance to center
517-
distances = np.linalg.norm(trajectory[shot_ids, 1:-1], axis=-1)
518-
distances = np.tile(distances.reshape((len(shot_ids), -1, 1)), (1, 1, 10))
519-
signal = 1 - distances.reshape((len(shot_ids), -1)) / np.max(distances)
520-
signal = (
521-
signal * np.exp(2j * np.pi * figsize / 100 * np.arange(signal.shape[1]))
522-
).real
523-
signal = signal * np.abs(signal) ** 3
524-
525-
colors = displayConfig.get_colorlist()
526-
# Show signal for each requested shot
527-
axes[-1].set_ylim((-1, 1))
528-
axes[-1].set_ylabel("Signal", fontsize=displayConfig.fontsize)
529-
for j in range(len(shot_ids)):
530-
color = (
531-
uni_signal
532-
if (uni_signal is not None)
533-
else colors[j % displayConfig.nb_colors]
534-
)
535-
axes[-1].plot(np.arange(signal.shape[1]), signal[j], color=color)
512+
# Set the norm axis if requested
513+
gradient_norm = np.linalg.norm(gradients[s_id], axis=-1)
514+
axes[-1].set_ylabel("|G|", fontsize=displayConfig.fontsize)
515+
axes[-1].plot(gradient_norm, color=color)
516+
if fill_area:
517+
axes[-1].fill_between(
518+
time_axis,
519+
gradient_norm,
520+
alpha=displayConfig.alpha,
521+
color=color,
522+
)
536523
return axes
537524

538525

@@ -541,7 +528,7 @@ def display_gradients(
541528
shot_ids: tuple[int, ...] = (0,),
542529
figsize: float = 5,
543530
fill_area: bool = True,
544-
show_signal: bool = True,
531+
show_norm: bool = True,
545532
uni_signal: str | None = "gray",
546533
uni_gradient: str | None = None,
547534
subfigure: plt.Figure | plt.Axes | None = None,
@@ -567,7 +554,7 @@ def display_gradients(
567554
Fills the area under the curve for improved visibility and
568555
representation of the integral, aka trajectory.
569556
The default is `True`.
570-
show_signal : bool, optional
557+
show_norm : bool, optional
571558
Show an additional illustration of the signal as
572559
the modulated distance to the center.
573560
The default is `True`.
@@ -619,7 +606,7 @@ def display_gradients(
619606
shot_ids,
620607
figsize,
621608
fill_area,
622-
show_signal,
609+
show_norm,
623610
uni_signal,
624611
uni_gradient,
625612
subfigure,
@@ -633,7 +620,7 @@ def display_gradients(
633620
fontsize=displayConfig.small_fontsize,
634621
)
635622
axes[-1].set_xlabel("Time (ms)", fontsize=displayConfig.small_fontsize)
636-
if show_signal:
623+
if show_norm:
637624
axes[-1].set_ylabel("Signal (a.u.)", fontsize=displayConfig.small_fontsize)
638625

639626
# Update axis ticks with rescaled values
@@ -642,7 +629,7 @@ def display_gradients(
642629
if ax == axes[-1]:
643630
ax.xaxis.set_tick_params(labelbottom=True)
644631
ticks = ax.get_xticks()
645-
scale = (0.1 if (show_signal and ax == axes[-1]) else 1) * raster_time
632+
scale = (0.1 if (show_norm and ax == axes[-1]) else 1) * raster_time
646633
locator = mticker.FixedLocator(ticks)
647634
formatter = mticker.FixedFormatter(np.around(scale * ticks, 2))
648635
ax.xaxis.set_major_locator(locator)
@@ -663,7 +650,7 @@ def display_gradients(
663650
scale = 1e3 * scale # Convert from T/m to mT/m
664651
locator = mticker.FixedLocator(ticks)
665652
formatter = mticker.FixedFormatter(np.around(scale * ticks, 1))
666-
if not show_signal or ax != axes[-1]:
653+
if not show_norm or ax != axes[-1]:
667654
ax.yaxis.set_major_locator(locator)
668655
ax.yaxis.set_major_formatter(formatter)
669656

0 commit comments

Comments
 (0)