Skip to content

Commit ebd8bb6

Browse files
committed
load from mrd + 2d adjoint reconstructor
1 parent 2442e66 commit ebd8bb6

File tree

8 files changed

+146
-16
lines changed

8 files changed

+146
-16
lines changed

src/cli-conf/scenario2-2d.yaml

Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,93 @@
1+
# This files contains the configuration to reproduce the scenario 1 of the Snkf paper.
2+
3+
defaults:
4+
- base_config
5+
- handlers:
6+
- activation-block
7+
- sampler:
8+
- stack-of-spiral
9+
- reconstructors:
10+
- adjoint
11+
#- sequential
12+
- _self_
13+
14+
cache_dir: ${oc.env:PWD}/cache
15+
result_dir: results/scenario2
16+
filename: ${cache_dir}/scenario2_2d_${engine.model}_${engine.snr}_${sampler.stack-of-spiral.constant}_${sampler.stack-of-spiral.accelz}.mrd
17+
18+
sim_conf:
19+
max_sim_time: 360
20+
seq: {TR: 50, TE: 25, FA: 12}
21+
hardware:
22+
n_coils: 1
23+
dwell_time_ms: 0.001
24+
shape: [60, 72, 60]
25+
fov_mm: [181.0, 217.0, 181.0]
26+
27+
phantom:
28+
name: brainweb
29+
sub_id: 4
30+
tissue_file: "tissue_7T"
31+
32+
handlers:
33+
activation-block:
34+
event_name: block_on
35+
block_on: 20 # seconds
36+
block_off: 20 #seconds
37+
duration: 360 # seconds
38+
#delta_r2s: 1000 # millisecond^-1
39+
40+
sampler:
41+
stack-of-spiral:
42+
acsz: 1
43+
accelz: 1
44+
nb_revolutions: 10
45+
constant: true
46+
spiral_name: "galilean"
47+
48+
engine:
49+
n_jobs: 1
50+
chunk_size: 10
51+
model: "simple"
52+
snr: 10000
53+
nufft_backend: "gpuNUFFT"
54+
slice_2d: true
55+
56+
reconstructors:
57+
adjoint:
58+
nufft_backend: "gpuNUFFT"
59+
density_compensation: "pipe"
60+
# sequential:
61+
# nufft_backend: "gpuNUFFT"
62+
# density_compensation: false
63+
# restart_strategy: WARM
64+
# max_iter_per_frame: 50
65+
# wavelet: "sym4"
66+
67+
68+
69+
70+
71+
hydra:
72+
job:
73+
chdir: true
74+
75+
run:
76+
dir: ${result_dir}/outputs/${hydra.job.name}/${now:%Y-%m-%d_%H-%M-%S}
77+
sweep:
78+
dir: ${result_dir}/multirun/${hydra.job.name}/${now:%Y-%m-%d_%H-%M-%S}
79+
subdir: ${hydra.job.num}
80+
81+
callbacks:
82+
# gather_files:
83+
# _target_: hydra_callbacks.MultiRunGatherer
84+
# aggregator:
85+
# _partial_: true
86+
# _target_: snkf.cli.utils.aggregate_results
87+
88+
log_job:
89+
_target_: hydra.experimental.callbacks.LogJobReturnCallback
90+
latest_run:
91+
_target_: hydra_callbacks.LatestRunLink
92+
run_base_dir: ${result_dir}/outputs
93+
multirun_base_dir: ${result_dir}/multirun

src/snake/core/engine/base.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -209,7 +209,17 @@ def __call__(
209209
`_job_model_simple` methods.
210210
"""
211211
# Create the base dataset
212-
make_base_mrd(filename, sampler, phantom, sim_conf, handlers, smaps, coil_cov, self.model, self.slice_2d)
212+
make_base_mrd(
213+
filename,
214+
sampler,
215+
phantom,
216+
sim_conf,
217+
handlers,
218+
smaps,
219+
coil_cov,
220+
self.model,
221+
self.slice_2d,
222+
)
213223

214224
# Guesstimate the workload
215225
if worker_chunk_size <= 0:

src/snake/mrd_utils/loader.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -211,6 +211,16 @@ def n_shots(self) -> int:
211211
"""
212212
return self.header.encoding[0].encodingLimits.kspace_encoding_step_1.maximum
213213

214+
@property
215+
def engine_model(self) -> str:
216+
"""Get the engine model."""
217+
return self.header.userParameters.userParameterString[0].value
218+
219+
@property
220+
def slice_2d(self) -> bool:
221+
"""Is the slice 2D."""
222+
return bool(self.header.userParameters.userParameterString[1].value)
223+
214224
#############
215225
# Get data #
216226
#############

src/snake/mrd_utils/writer.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,9 @@
2222
log = logging.getLogger(__name__)
2323

2424

25-
def get_mrd_header(sim_conf: SimConfig, engine: str, model: str, slice_2d: bool) -> mrd.xsd.ismrmrdHeader:
25+
def get_mrd_header(
26+
sim_conf: SimConfig, engine: str, model: str, slice_2d: bool
27+
) -> mrd.xsd.ismrmrdHeader:
2628
"""Create a MRD Header for snake-fmri data."""
2729
H = mrd.xsd.ismrmrdHeader()
2830
# Experimental conditions
@@ -85,7 +87,7 @@ def get_mrd_header(sim_conf: SimConfig, engine: str, model: str, slice_2d: bool)
8587
("engine_model", model),
8688
("slice_2d", str(slice_2d)),
8789
]
88-
]
90+
],
8991
)
9092

9193
return H

src/snake/toolkit/cli/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
""""Command Line Interface for SNAKE."""
1+
""" "Command Line Interface for SNAKE."""

src/snake/toolkit/cli/config.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -125,5 +125,5 @@ def cleanup_cuda() -> None:
125125
def make_hydra_cli(fun: callable) -> callable:
126126
"""Create a Hydra CLI for the function."""
127127
return hydra.main(
128-
version_base=None, config_path="../../../cli-conf", config_name="config"
128+
version_base=None, config_path="../../../cli-conf", config_name="scenario2-2d"
129129
)(fun)

src/snake/toolkit/cli/reconstruction.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ def reconstruction(cfg: DictConfig) -> None:
5757
data_rec_file = Path(f"data_rec_{rec_str}.npy")
5858
log.info(f"Using {name} reconstructor")
5959
rec.setup(sim_conf)
60-
rec_data = rec.reconstruct(data_loader, sim_conf)
60+
rec_data = rec.reconstruct(data_loader, sim_conf, data_loader.slice_2d)
6161
log.info(f"Reconstruction done with {name}")
6262
# Save the reconstruction
6363
np.save(data_rec_file, rec_data)

src/snake/toolkit/reconstructors/pysap.py

Lines changed: 25 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -73,18 +73,20 @@ def setup(self, sim_conf: SimConfig) -> None:
7373
"""Initialize Reconstructor."""
7474
pass
7575

76-
def reconstruct(self, data_loader: MRDLoader, sim_conf: SimConfig) -> NDArray:
76+
def reconstruct(
77+
self, data_loader: MRDLoader, sim_conf: SimConfig, slice_2d: bool
78+
) -> NDArray:
7779
"""Reconstruct data with zero-filled method."""
7880
with data_loader:
7981
if isinstance(data_loader, CartesianFrameDataLoader):
80-
return self._reconstruct_cartesian(data_loader, sim_conf)
82+
return self._reconstruct_cartesian(data_loader, sim_conf, slice_2d)
8183
elif isinstance(data_loader, NonCartesianFrameDataLoader):
82-
return self._reconstruct_nufft(data_loader, sim_conf)
84+
return self._reconstruct_nufft(data_loader, sim_conf, slice_2d)
8385
else:
8486
raise ValueError("Unknown dataloader")
8587

8688
def _reconstruct_cartesian(
87-
self, data_loader: CartesianFrameDataLoader, sim_conf: SimConfig
89+
self, data_loader: CartesianFrameDataLoader, sim_conf: SimConfig, slice_2d
8890
) -> NDArray:
8991
smaps = data_loader.get_smaps()
9092
if smaps is None and data_loader.n_coils > 1:
@@ -114,7 +116,6 @@ def _reconstruct_cartesian(
114116
): idx
115117
for idx in range(data_loader.n_frames)
116118
}
117-
118119
for future in as_completed(futures):
119120
future.result()
120121
pbar.update(1)
@@ -126,16 +127,21 @@ def _reconstruct_cartesian(
126127
return final_images
127128

128129
def _reconstruct_nufft(
129-
self, data_loader: NonCartesianFrameDataLoader, sim_conf: SimConfig
130+
self, data_loader: NonCartesianFrameDataLoader, sim_conf: SimConfig, slice_2d
130131
) -> NDArray:
131132
"""Reconstruct data with nufft method."""
132133
from mrinufft import get_operator
133134

134135
smaps = data_loader.get_smaps()
135-
136+
shape = data_loader.shape
136137
traj, kspace_data = data_loader.get_kspace_frame(0)
138+
139+
if slice_2d:
140+
shape = data_loader.shape[:2]
141+
traj = traj.reshape(data_loader.n_shots, -1, traj.shape[-1])[0, :, :2]
142+
137143
kwargs = dict(
138-
shape=data_loader.shape,
144+
shape=shape,
139145
n_coils=data_loader.n_coils,
140146
smaps=smaps,
141147
)
@@ -146,6 +152,7 @@ def _reconstruct_nufft(
146152
kwargs["density"] = self.density_compensation
147153
if "stacked" in self.nufft_backend:
148154
kwargs["z_index"] = "auto"
155+
149156
nufft_operator = get_operator(
150157
self.nufft_backend,
151158
samples=traj,
@@ -158,8 +165,16 @@ def _reconstruct_nufft(
158165

159166
for i in tqdm(range(data_loader.n_frames)):
160167
traj, data = data_loader.get_kspace_frame(i)
161-
nufft_operator.samples = traj
162-
final_images[i] = abs(nufft_operator.adj_op(data))
168+
if slice_2d:
169+
nufft_operator.samples = traj.reshape(
170+
data_loader.n_shots, -1, traj.shape[-1]
171+
)[0, :, :2]
172+
data = np.reshape(data, (data.shape[0], data_loader.n_shots, -1))
173+
for j in range(data.shape[1]):
174+
final_images[i, :, :, j] = abs(nufft_operator.adj_op(data[:, j]))
175+
else:
176+
nufft_operator.samples = traj
177+
final_images[i] = abs(nufft_operator.adj_op(data))
163178
return final_images
164179

165180

0 commit comments

Comments
 (0)