diff --git a/src/cli-conf/scenario2-2d.yaml b/src/cli-conf/scenario2-2d.yaml index 011a3785..317475eb 100644 --- a/src/cli-conf/scenario2-2d.yaml +++ b/src/cli-conf/scenario2-2d.yaml @@ -5,7 +5,7 @@ defaults: - handlers: - activation-block - sampler: - - stack-of-spiral + - rotated-stack-of-spiral - reconstructors: - adjoint #- sequential @@ -35,6 +35,7 @@ handlers: block_on: 20 # seconds block_off: 20 #seconds duration: 360 # seconds + delta_r2s: 1000 # millisecond^-1 sampler: stack-of-spiral: @@ -43,6 +44,13 @@ sampler: nb_revolutions: 10 constant: true spiral_name: "galilean" + rotated-stack-of-spiral: + acsz: 1 + accelz: 1 + nb_revolutions: 10 + constant: false + spiral_name: "galilean" + rotate_frame_angle: 0 engine: n_jobs: 1 diff --git a/src/snake/core/sampling/__init__.py b/src/snake/core/sampling/__init__.py index ad86234f..b763f9db 100644 --- a/src/snake/core/sampling/__init__.py +++ b/src/snake/core/sampling/__init__.py @@ -4,6 +4,7 @@ from .samplers import ( EPI3dAcquisitionSampler, StackOfSpiralSampler, + RotatedStackOfSpiralSampler, NonCartesianAcquisitionSampler, EVI3dAcquisitionSampler, LoadTrajectorySampler, @@ -15,5 +16,6 @@ "EPI3dAcquisitionSampler", "EVI3dAcquisitionSampler", "StackOfSpiralSampler", + "RotatedStackOfSpiralSampler", "NonCartesianAcquisitionSampler", ] diff --git a/src/snake/core/sampling/samplers.py b/src/snake/core/sampling/samplers.py index 405856c5..51237619 100644 --- a/src/snake/core/sampling/samplers.py +++ b/src/snake/core/sampling/samplers.py @@ -14,10 +14,12 @@ stack_spiral_factory, stacked_epi_factory, evi_factory, + rotate_trajectory, ) from snake.mrd_utils.utils import ACQ from snake._meta import batched, EnvConfig from mrinufft.io import read_trajectory +from collections.abc import Generator class NonCartesianAcquisitionSampler(BaseSampler): @@ -276,6 +278,48 @@ def _single_frame(self, sim_conf: SimConfig) -> NDArray: ) +class RotatedStackOfSpiralSampler(StackOfSpiralSampler): + """ + Spiral 2D Acquisition Handler to generate k-space data. + + Parameters + ---------- + rotate_frame_angle: AngleRotation | int + Angle of rotation of the frame. + frame_index: int + Index of the frame. + **kwargs: + Extra arguments (smaps, n_jobs, backend etc...) + """ + + __sampler_name__ = "rotated-stack-of-spiral" + rotate_frame_angle: AngleRotation | int = 0 + frame_index: int = 0 + + def fix_angle_rotation( + self, frame: Generator[np.ndarray, None, None], angle: AngleRotation | float = 0 + ) -> Generator[np.ndarray, None, None]: + """Rotate the trajectory by a given angle.""" + for traj in frame: + yield from rotate_trajectory((x for x in [traj]), angle) + + def get_next_frame(self, sim_conf: SimConfig) -> NDArray: + """Generate the next rotated frame.""" + base_frame = self._single_frame(sim_conf) + if self.constant or self.rotate_frame_angle == 0: + return base_frame + else: + self.frame_index += 1 + rotate_frame_angle = np.pi * (self.rotate_frame_angle / 180) + base_frame_gen = (traj[None, ...] for traj in base_frame) + rotated_frame = self.fix_angle_rotation( + base_frame_gen, float(rotate_frame_angle * self.frame_index) + ) + return np.concatenate( + [traj.astype(np.float32) for traj in rotated_frame], axis=0 + ) + + class EPI3dAcquisitionSampler(BaseSampler): """Sampling pattern for EPI-3D.""" diff --git a/src/snake/toolkit/cli/config.py b/src/snake/toolkit/cli/config.py index 7611cf82..39d2a4e2 100644 --- a/src/snake/toolkit/cli/config.py +++ b/src/snake/toolkit/cli/config.py @@ -125,5 +125,5 @@ def cleanup_cuda() -> None: def make_hydra_cli(fun: callable) -> callable: """Create a Hydra CLI for the function.""" return hydra.main( - version_base=None, config_path="../../../cli-conf", config_name="config" + version_base=None, config_path="../../../cli-conf", config_name="scenario2-2d" )(fun) diff --git a/src/snake/toolkit/reconstructors/pysap.py b/src/snake/toolkit/reconstructors/pysap.py index d2e3049a..051fb984 100644 --- a/src/snake/toolkit/reconstructors/pysap.py +++ b/src/snake/toolkit/reconstructors/pysap.py @@ -8,6 +8,7 @@ import numpy as np from numpy.typing import NDArray +from mrinufft.density import get_density # Local imports from snake.mrd_utils import ( @@ -155,9 +156,15 @@ def _reconstruct_nufft( kwargs["density"] = None else: kwargs["density"] = self.density_compensation + method = self.density_compensation if "stacked" in self.nufft_backend: kwargs["z_index"] = "auto" + if isinstance(method, str): + method = get_density(method) + if not callable(method): + raise ValueError(f"Unknown density method: {method}") + nufft_operator = get_operator( self.nufft_backend, samples=traj, @@ -171,9 +178,11 @@ def _reconstruct_nufft( for i in tqdm(range(data_loader.n_frames)): traj, data = data_loader.get_kspace_frame(i) if data_loader.slice_2d: - nufft_operator.samples = traj.reshape( - data_loader.n_shots, -1, traj.shape[-1] - )[0, :, :2] + traj = traj.reshape(data_loader.n_shots, -1, traj.shape[-1]) + nufft_operator.samples = traj[0, :, :2] + nufft_operator.density = method( + traj[:, :, :2], shape, backend=self.nufft_backend + ) data = np.reshape(data, (data.shape[0], data_loader.n_shots, -1)) for j in range(data.shape[1]): final_images[i, :, :, j] = abs(nufft_operator.adj_op(data[:, j]))