diff --git a/examples/trajectories/example_2D_trajectories.py b/examples/trajectories/example_2D_trajectories.py index b0eee99bb..9852a0197 100644 --- a/examples/trajectories/example_2D_trajectories.py +++ b/examples/trajectories/example_2D_trajectories.py @@ -32,7 +32,7 @@ # Trajectory parameters Nc = 24 # Number of shots Ns = 256 # Number of samples per shot -in_out = True # Choose between in-out or center-out trajectories +in_out = False # Choose between in-out or center-out trajectories tilt = "uniform" # Choose the angular distance between shots nb_repetitions = 6 # Number of strips when relevant seed = 0 # Seed for random trajectories diff --git a/src/mrinufft/trajectories/tools.py b/src/mrinufft/trajectories/tools.py index 4b67e1e15..e63fed01f 100644 --- a/src/mrinufft/trajectories/tools.py +++ b/src/mrinufft/trajectories/tools.py @@ -6,6 +6,8 @@ from numpy.typing import NDArray from scipy.interpolate import CubicSpline, interp1d from scipy.stats import norm +import inspect +from functools import wraps from scipy.optimize import minimize_scalar from joblib import Parallel, delayed @@ -15,9 +17,13 @@ VDSorder, VDSpdf, initialize_tilt, + DEFAULT_RESOLUTION, DEFAULT_GMAX, DEFAULT_RASTER_TIME, DEFAULT_SMAX, + unnormalize_trajectory, + convert_gradients_to_trajectory, + convert_trajectory_to_gradients, Gammas, ) @@ -380,6 +386,31 @@ def unepify(trajectory: NDArray, Ns_readouts: int, Ns_transitions: int) -> NDArr return trajectory +def _set_defaults_gradient_calc( + kspace_end_loc: NDArray, + kspace_start_loc: Optional[NDArray] = None, + end_gradients: Optional[NDArray] = None, + start_gradients: Optional[NDArray] = None, +): + kspace_end_loc = np.atleast_2d(kspace_end_loc) + if kspace_start_loc is None: + kspace_start_loc = np.zeros_like(kspace_end_loc) + if start_gradients is None: + start_gradients = np.zeros_like(kspace_end_loc) + if end_gradients is None: + end_gradients = np.zeros_like(kspace_end_loc) + kspace_start_loc = np.atleast_2d(kspace_start_loc) + start_gradients = np.atleast_2d(start_gradients) + end_gradients = np.atleast_2d(end_gradients) + assert ( + kspace_start_loc.shape + == kspace_end_loc.shape + == start_gradients.shape + == end_gradients.shape + ), "All input arrays must have shape (nb_shots, nb_dimension)" + return kspace_end_loc, kspace_start_loc, start_gradients, end_gradients + + def _trapezoidal_area(gs, ge, gi, n_down, n_up, n_pl): """Calculate the area traversed by the trapezoidal gradient waveform.""" return 0.5 * (gs + gi) * (n_down + 1) + 0.5 * (ge + gi) * (n_up - 1) + n_pl * gi @@ -417,7 +448,7 @@ def _plateau_value(gs, ge, n_down, n_up, n_pl, area_needed): def get_gradient_times_to_travel( - kspace_end_loc: Optional[NDArray] = None, + kspace_end_loc: NDArray, kspace_start_loc: Optional[NDArray] = None, end_gradients: Optional[NDArray] = None, start_gradients: Optional[NDArray] = None, @@ -472,6 +503,14 @@ def get_gradient_times_to_travel( To directly get the waveforms required. This is most-likely what you want to use. """ + kspace_end_loc, kspace_start_loc, start_gradients, end_gradients = ( + _set_defaults_gradient_calc( + kspace_end_loc, + kspace_start_loc, + end_gradients, + start_gradients, + ) + ) area_needed = (kspace_end_loc - kspace_start_loc) / gamma / raster_time def solve_gi_min_plateau(gs, ge, area): @@ -582,23 +621,14 @@ def get_gradient_amplitudes_to_travel_for_set_time( - The returned gradients are suitable for use in MRI pulse sequence design, ensuring compliance with specified hardware constraints. """ - kspace_end_loc = np.atleast_2d(kspace_end_loc) - if kspace_start_loc is None: - kspace_start_loc = np.zeros_like(kspace_end_loc) - if start_gradients is None: - start_gradients = np.zeros_like(kspace_end_loc) - if end_gradients is None: - end_gradients = np.zeros_like(kspace_end_loc) - kspace_start_loc = np.atleast_2d(kspace_start_loc) - start_gradients = np.atleast_2d(start_gradients) - end_gradients = np.atleast_2d(end_gradients) - - assert ( - kspace_start_loc.shape - == kspace_end_loc.shape - == start_gradients.shape - == end_gradients.shape - ), "All input arrays must have shape (nb_shots, nb_dimension)" + kspace_end_loc, kspace_start_loc, start_gradients, end_gradients = ( + _set_defaults_gradient_calc( + kspace_end_loc, + kspace_start_loc, + end_gradients, + start_gradients, + ) + ) if nb_raster_points is None: # Calculate the number of time steps based on the area needed n_ramp_down, n_ramp_up, n_plateau, gi = get_gradient_times_to_travel( @@ -1268,3 +1298,159 @@ def stack_random( new_trajectory[i, :, :, 2] = loc return new_trajectory.reshape(-1, Ns, 3) + + +def _add_slew_ramp_to_traj_func( + func: Callable, + func_kwargs: dict, + ramp_to_index: int, + resolution: float, + raster_time: float, + gamma: float, + smax: float, +): + traj = func(**func_kwargs) + unnormalized_traj = unnormalize_trajectory(traj, resolution=resolution) + gradients, initial_positions = convert_trajectory_to_gradients( + traj, resolution=resolution, raster_time=raster_time, gamma=gamma + ) + gradients_to_reach = gradients[:, ramp_to_index] + # Calculate the number of time steps for ramps + n_ramp_down, n_ramp_up, n_plateau, gi = get_gradient_times_to_travel( + kspace_end_loc=unnormalized_traj[:, ramp_to_index], + end_gradients=gradients_to_reach, + gamma=gamma, + raster_time=raster_time, + smax=smax, + n_jobs=-1, # Use all available cores + ) + # Update the Ns of the trajectory to ensure we still give + # same Ns as users expect. We use extra 2 points as buffer. + n_slew_ramp = np.max(n_ramp_down + n_ramp_up + n_plateau) + func_kwargs["Ns"] -= n_slew_ramp - ramp_to_index + new_traj = func(**func_kwargs) + # Re-calculate the gradients + unnormalized_traj = unnormalize_trajectory(new_traj, resolution=resolution) + gradients, initial_positions = convert_trajectory_to_gradients( + new_traj, resolution=resolution, raster_time=raster_time, gamma=gamma + ) + gradients_to_reach = gradients[:, ramp_to_index] + ramp_up_gradients = get_gradient_amplitudes_to_travel_for_set_time( + kspace_end_loc=unnormalized_traj[:, ramp_to_index], + end_gradients=gradients_to_reach, + nb_raster_points=n_slew_ramp, + gamma=gamma, + raster_time=raster_time, + smax=smax, + n_jobs=-1, # Use all available core + )[:, :-1] + ramp_up_traj = convert_gradients_to_trajectory( + gradients=ramp_up_gradients, + initial_positions=initial_positions, + resolution=resolution, + raster_time=raster_time, + gamma=gamma, + ) + return np.hstack([ramp_up_traj, new_traj[:, ramp_to_index:]]) + + +def add_slew_ramp( + func: Optional[Callable] = None, + ramp_to_index: int = 5, + resolution: float | NDArray = DEFAULT_RESOLUTION, + raster_time: float = DEFAULT_RASTER_TIME, + gamma: float = Gammas.Hydrogen, + smax: float = DEFAULT_SMAX, + slew_ramp_disable: bool = False, +) -> Callable: + """Add slew-compatible ramps to a trajectory function. + + This decorator modifies a trajectory function to include + slew rate ramps, ensuring that the trajectory adheres to + the maximum slew rate and gradient amplitude constraints. + The ramps are applied to the gradients of the trajectory + at the specified `ramp_to_index`, which is by-default the + index of the 5th readout sample. + Note that this decorator does not change the length of the original + trajectory. + + Parameters + ---------- + func : Optional[Callable], optional + The trajectory function to decorate. If not provided, + the decorator can be used without arguments. + ramp_to_index : int, optional + The index in the trajectory where the slew ramp should be applied, + by default 5. This is typically the index of the first readout sample. + resolution : float or NDArray, optional + The resolution of the trajectory, by default DEFAULT_RESOLUTION. + This can be a single float or an array-like of shape (3,). + raster_time : float, optional + The time interval between samples in the trajectory, by default + DEFAULT_RASTER_TIME. + gamma : float, optional + The gyromagnetic ratio in Hz/T, by default Gammas.Hydrogen. + smax : float, optional + The maximum slew rate in T/m/s, by default DEFAULT_SMAX. + slew_ramp_disable : bool, optional + If True, disables the slew ramp and returns the trajectory as is, + by default False. This is useful for in-out trajectories where + the slew ramp is not needed. + + Returns + ------- + Callable + A decorator that modifies the trajectory function to include + slew rate ramps. + + Notes + ----- + - The decorator modifies the trajectory function to ensure that the + gradients at the specified `ramp_to_index` are adjusted to comply with + the maximum slew rate and gradient amplitude constraints. + - If `slew_ramp_disable` is set to True, the trajectory function will + return the trajectory as is, without applying any slew ramps. + - The decorator can be used with or without providing a function. + - If used without a function, it returns a decorator that can be applied later. + - The decorated function should accept parameters like `smax`, `resolution`, + `raster_time`, `gamma`, and `ramp_to_index` to control the behavior of the + slew ramping. + """ + + def decorator(trajectory_func): + sig = inspect.signature(trajectory_func) + + @wraps(trajectory_func) + def wrapped(*args, **kwargs) -> NDArray: + # This allows users to also call the trajectory function + # directly giving these args. + _smax = kwargs.pop("smax", smax) + _resolution = kwargs.pop("resolution", resolution) + _raster_time = kwargs.pop("raster_time", raster_time) + _gamma = kwargs.pop("gamma", gamma) + _ramp_to_index = kwargs.pop("ramp_to_index", ramp_to_index) + _slew_ramp_disable = kwargs.pop("slew_ramp_disable", slew_ramp_disable) + # Bind all args (positional and keyword) + bound = sig.bind(*args, **kwargs) + bound.apply_defaults() + in_out = bound.arguments.get("in_out", False) + if in_out or _slew_ramp_disable: + traj = trajectory_func(*args, **kwargs) + # Send the trajectory as is for in-out trajectories + return traj + return _add_slew_ramp_to_traj_func( + trajectory_func, + bound.arguments, + ramp_to_index=_ramp_to_index, + resolution=_resolution, + raster_time=_raster_time, + gamma=_gamma, + smax=_smax, + ) + + return wrapped + + if func is not None and callable(func): + return decorator(func) + + return decorator diff --git a/src/mrinufft/trajectories/trajectory2D.py b/src/mrinufft/trajectories/trajectory2D.py index a955b1727..509fde7fc 100644 --- a/src/mrinufft/trajectories/trajectory2D.py +++ b/src/mrinufft/trajectories/trajectory2D.py @@ -9,7 +9,7 @@ from .gradients import patch_center_anomaly from .maths import R2D, compute_coprime_factors, is_from_fibonacci_sequence -from .tools import rotate +from .tools import rotate, add_slew_ramp from .utils import KMAX, initialize_algebraic_spiral, initialize_tilt ##################### @@ -17,6 +17,7 @@ ##################### +@add_slew_ramp def initialize_2D_radial( Nc: int, Ns: int, tilt: str | float = "uniform", in_out: bool = False ) -> NDArray: @@ -51,6 +52,7 @@ def initialize_2D_radial( return trajectory +@add_slew_ramp def initialize_2D_spiral( Nc: int, Ns: int, @@ -229,6 +231,7 @@ def initialize_2D_fibonacci_spiral( return trajectory +@add_slew_ramp def initialize_2D_cones( Nc: int, Ns: int, @@ -275,6 +278,7 @@ def initialize_2D_cones( return trajectory +@add_slew_ramp def initialize_2D_sinusoide( Nc: int, Ns: int, @@ -415,6 +419,7 @@ def initialize_2D_rings(Nc: int, Ns: int, nb_rings: int) -> NDArray: return KMAX * np.array(trajectory) +@add_slew_ramp def initialize_2D_rosette( Nc: int, Ns: int, in_out: bool = False, coprime_index: int = 0 ) -> NDArray: @@ -458,6 +463,7 @@ def initialize_2D_rosette( return trajectory +@add_slew_ramp def initialize_2D_polar_lissajous( Nc: int, Ns: int, in_out: bool = False, nb_segments: int = 1, coprime_index: int = 0 ) -> NDArray: diff --git a/src/mrinufft/trajectories/trajectory3D.py b/src/mrinufft/trajectories/trajectory3D.py index 5ec9f2bca..260a11a86 100644 --- a/src/mrinufft/trajectories/trajectory3D.py +++ b/src/mrinufft/trajectories/trajectory3D.py @@ -17,7 +17,7 @@ Rz, generate_fibonacci_circle, ) -from .tools import conify, duplicate_along_axes, epify, precess, stack +from .tools import conify, duplicate_along_axes, epify, precess, stack, add_slew_ramp from .trajectory2D import initialize_2D_radial, initialize_2D_spiral from .utils import KMAX, Packings, Spirals, initialize_shape_norm, initialize_tilt @@ -75,6 +75,7 @@ def initialize_3D_phyllotaxis_radial( return trajectory +@add_slew_ramp def initialize_3D_golden_means_radial( Nc: int, Ns: int, in_out: bool = False ) -> NDArray: @@ -129,6 +130,7 @@ def initialize_3D_golden_means_radial( return KMAX * trajectory +@add_slew_ramp def initialize_3D_wong_radial( Nc: int, Ns: int, nb_interleaves: int = 1, in_out: bool = False ) -> NDArray: @@ -242,6 +244,7 @@ def initialize_3D_park_radial( ############################ +@add_slew_ramp def initialize_3D_cones( Nc: int, Ns: int, @@ -296,6 +299,7 @@ def initialize_3D_cones( spiral=spiral, in_out=in_out, nb_revolutions=nb_zigzags, + slew_ramp_disable=True, ) # Estimate best cone angle based on the ratio between @@ -525,6 +529,7 @@ def initialize_3D_wave_caipi( return KMAX * trajectory +@add_slew_ramp def initialize_3D_seiffert_spiral( Nc: int, Ns: int, diff --git a/src/mrinufft/trajectories/utils.py b/src/mrinufft/trajectories/utils.py index 4f5a0e077..e74a11400 100644 --- a/src/mrinufft/trajectories/utils.py +++ b/src/mrinufft/trajectories/utils.py @@ -1,8 +1,10 @@ """Utility functions in general.""" +import inspect +from functools import wraps from enum import Enum, EnumMeta from numbers import Real -from typing import Any, Literal +from typing import Any import numpy as np from numpy.typing import NDArray