Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 9 additions & 1 deletion src/cli-conf/scenario2-2d.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ defaults:
- handlers:
- activation-block
- sampler:
- stack-of-spiral
- rotated-stack-of-spiral
- reconstructors:
- adjoint
#- sequential
Expand Down Expand Up @@ -35,6 +35,7 @@ handlers:
block_on: 20 # seconds
block_off: 20 #seconds
duration: 360 # seconds
delta_r2s: 1000 # millisecond^-1

sampler:
stack-of-spiral:
Expand All @@ -43,6 +44,13 @@ sampler:
nb_revolutions: 10
constant: true
spiral_name: "galilean"
rotated-stack-of-spiral:
acsz: 1
accelz: 1
nb_revolutions: 10
constant: false
spiral_name: "galilean"
rotate_frame_angle: 0

engine:
n_jobs: 1
Expand Down
2 changes: 2 additions & 0 deletions src/snake/core/sampling/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from .samplers import (
EPI3dAcquisitionSampler,
StackOfSpiralSampler,
RotatedStackOfSpiralSampler,
NonCartesianAcquisitionSampler,
EVI3dAcquisitionSampler,
LoadTrajectorySampler,
Expand All @@ -15,5 +16,6 @@
"EPI3dAcquisitionSampler",
"EVI3dAcquisitionSampler",
"StackOfSpiralSampler",
"RotatedStackOfSpiralSampler",
"NonCartesianAcquisitionSampler",
]
44 changes: 44 additions & 0 deletions src/snake/core/sampling/samplers.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,12 @@
stack_spiral_factory,
stacked_epi_factory,
evi_factory,
rotate_trajectory,
)
from snake.mrd_utils.utils import ACQ
from snake._meta import batched, EnvConfig
from mrinufft.io import read_trajectory
from collections.abc import Generator


class NonCartesianAcquisitionSampler(BaseSampler):
Expand Down Expand Up @@ -276,6 +278,48 @@ def _single_frame(self, sim_conf: SimConfig) -> NDArray:
)


class RotatedStackOfSpiralSampler(StackOfSpiralSampler):
"""
Spiral 2D Acquisition Handler to generate k-space data.

Parameters
----------
rotate_frame_angle: AngleRotation | int
Angle of rotation of the frame.
frame_index: int
Index of the frame.
**kwargs:
Extra arguments (smaps, n_jobs, backend etc...)
"""

__sampler_name__ = "rotated-stack-of-spiral"
rotate_frame_angle: AngleRotation | int = 0
frame_index: int = 0

def fix_angle_rotation(
self, frame: Generator[np.ndarray, None, None], angle: AngleRotation | float = 0
) -> Generator[np.ndarray, None, None]:
"""Rotate the trajectory by a given angle."""
for traj in frame:
yield from rotate_trajectory((x for x in [traj]), angle)

def get_next_frame(self, sim_conf: SimConfig) -> NDArray:
"""Generate the next rotated frame."""
base_frame = self._single_frame(sim_conf)
if self.constant or self.rotate_frame_angle == 0:
return base_frame
else:
self.frame_index += 1
rotate_frame_angle = np.pi * (self.rotate_frame_angle / 180)
base_frame_gen = (traj[None, ...] for traj in base_frame)
rotated_frame = self.fix_angle_rotation(
base_frame_gen, float(rotate_frame_angle * self.frame_index)
)
return np.concatenate(
[traj.astype(np.float32) for traj in rotated_frame], axis=0
)


class EPI3dAcquisitionSampler(BaseSampler):
"""Sampling pattern for EPI-3D."""

Expand Down
2 changes: 1 addition & 1 deletion src/snake/toolkit/cli/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,5 +125,5 @@ def cleanup_cuda() -> None:
def make_hydra_cli(fun: callable) -> callable:
"""Create a Hydra CLI for the function."""
return hydra.main(
version_base=None, config_path="../../../cli-conf", config_name="config"
version_base=None, config_path="../../../cli-conf", config_name="scenario2-2d"
)(fun)
15 changes: 12 additions & 3 deletions src/snake/toolkit/reconstructors/pysap.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

import numpy as np
from numpy.typing import NDArray
from mrinufft.density import get_density

# Local imports
from snake.mrd_utils import (
Expand Down Expand Up @@ -155,9 +156,15 @@ def _reconstruct_nufft(
kwargs["density"] = None
else:
kwargs["density"] = self.density_compensation
method = self.density_compensation
if "stacked" in self.nufft_backend:
kwargs["z_index"] = "auto"

if isinstance(method, str):
method = get_density(method)
if not callable(method):
raise ValueError(f"Unknown density method: {method}")

nufft_operator = get_operator(
self.nufft_backend,
samples=traj,
Expand All @@ -171,9 +178,11 @@ def _reconstruct_nufft(
for i in tqdm(range(data_loader.n_frames)):
traj, data = data_loader.get_kspace_frame(i)
if data_loader.slice_2d:
nufft_operator.samples = traj.reshape(
data_loader.n_shots, -1, traj.shape[-1]
)[0, :, :2]
traj = traj.reshape(data_loader.n_shots, -1, traj.shape[-1])
nufft_operator.samples = traj[0, :, :2]
nufft_operator.density = method(
traj[:, :, :2], shape, backend=self.nufft_backend
)
data = np.reshape(data, (data.shape[0], data_loader.n_shots, -1))
for j in range(data.shape[1]):
final_images[i, :, :, j] = abs(nufft_operator.adj_op(data[:, j]))
Expand Down
Loading