Skip to content

Commit f6198ca

Browse files
committed
slicedataloader for zerofilled
1 parent d0bd0e7 commit f6198ca

File tree

5 files changed

+71
-21
lines changed

5 files changed

+71
-21
lines changed

src/snake/mrd_utils/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
CartesianFrameDataLoader,
55
MRDLoader,
66
NonCartesianFrameDataLoader,
7+
SliceDataloader,
78
parse_sim_conf,
89
parse_waveform_information,
910
read_mrd_header,
@@ -16,6 +17,7 @@
1617
"MRDLoader",
1718
"CartesianFrameDataLoader",
1819
"NonCartesianFrameDataLoader",
20+
"SliceDataloader",
1921
"parse_sim_conf",
2022
"parse_waveform_information",
2123
"make_base_mrd",

src/snake/mrd_utils/loader.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -378,7 +378,48 @@ def get_coil_cov(self) -> NDArray | None:
378378
"""Load the coil covariance from the dataset."""
379379
return self._get_image_data("coil_cov")
380380

381+
class SliceDataloader(MRDLoader):
382+
"""Load slice MRD files k-space frames iteratively."""
383+
384+
def __init__(self,
385+
frame_dl: MRDLoader,
386+
):
387+
super().__init__(
388+
filename=frame_dl._filename,
389+
dataset_name=frame_dl._dataset_name,
390+
writeable=frame_dl._writeable,
391+
swmr=frame_dl._swmr,
392+
squeeze_dims=frame_dl._squeeze_dims
393+
)
394+
self.frame = frame_dl
395+
self.is_cartesian = isinstance(self.frame, CartesianFrameDataLoader)
396+
self.is_non_cartesian = isinstance(self.frame, NonCartesianFrameDataLoader)
397+
self.get_kspace_frame = self._get_kspace_frame
398+
399+
def __getattr__(self, name: str) -> Any:
400+
return getattr(self.frame, name)
401+
402+
@property
403+
def n_frames(self) -> int:
404+
"""Number of frames."""
405+
return self.frame.n_acquisition
406+
407+
@property
408+
def shape(self) -> tuple[int, int]:
409+
"""Shape of the volume."""
410+
return self.frame.shape[:2]
411+
412+
def _get_kspace_frame(
413+
self, idx: int, shot_dim: bool = False
414+
) -> tuple[NDArray[np.float32], NDArray[np.complex64]]:
415+
"""Get the k-space frame."""
416+
n_acq_per_frame = self.frame.n_acquisition // self.frame.n_frames
417+
traj, data = self.frame.get_kspace_frame(idx // self.frame.n_shots)
418+
traj = traj.reshape(n_acq_per_frame, -1, 3)
419+
data = data.reshape(self.frame.n_coils, n_acq_per_frame, -1)
381420

421+
return traj[idx%self.frame.n_shots,:,:2], data[:,idx%self.frame.n_shots,:]
422+
382423
class CartesianFrameDataLoader(MRDLoader):
383424
"""Load cartesian MRD files k-space frames iteratively.
384425

src/snake/toolkit/cli/reconstruction.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,7 @@ def reconstruction(cfg: DictConfig) -> None:
8181
raise ValueError("No dynamic data found matching waveform name")
8282

8383
bold_signal = good_d.data[0]
84-
bold_sample_time = np.arange(len(bold_signal)) * local_sim_conf.seq.TR / 1000
84+
bold_sample_time = np.arange(len(bold_signal)) * sim_conf.seq.TR / 1000
8585
del phantom
8686
del dyn_datas
8787
gc.collect()

src/snake/toolkit/reconstructors/fourier.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -60,14 +60,11 @@ def init_nufft(
6060
shape = data_loader.shape
6161
traj, _ = data_loader.get_kspace_frame(0)
6262

63-
if data_loader.slice_2d:
64-
shape = data_loader.shape[:2]
65-
traj = traj.reshape(data_loader.n_shots, -1, traj.shape[-1])[0, :, :2]
6663

6764
kwargs = dict(
6865
shape=shape,
6966
n_coils=data_loader.n_coils,
70-
smaps=smaps,
67+
smaps=smaps.squeeze() if data_loader.slice_2d else smaps,
7168
)
7269
if density_compensation is False:
7370
kwargs["density"] = None

src/snake/toolkit/reconstructors/pysap.py

Lines changed: 26 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
CartesianFrameDataLoader,
1515
MRDLoader,
1616
NonCartesianFrameDataLoader,
17+
SliceDataloader,
1718
)
1819
from snake.core.parallel import (
1920
ArrayProps,
@@ -83,12 +84,13 @@ def reconstruct(
8384
) -> NDArray:
8485
"""Reconstruct data with zero-filled method."""
8586
with data_loader:
86-
if isinstance(data_loader, CartesianFrameDataLoader):
87-
return self._reconstruct_cartesian(data_loader)
88-
elif isinstance(data_loader, NonCartesianFrameDataLoader):
89-
return self._reconstruct_nufft(data_loader)
87+
if isinstance(data_loader, SliceDataloader):
88+
reconstruct_method = self._reconstruct_nufft if data_loader.is_non_cartesian else self._reconstruct_cartesian
89+
elif isinstance(data_loader, CartesianFrameDataLoader | NonCartesianFrameDataLoader):
90+
reconstruct_method = self._reconstruct_cartesian if isinstance(data_loader, CartesianFrameDataLoader) else self._reconstruct_nufft
9091
else:
9192
raise ValueError("Unknown dataloader")
93+
return reconstruct_method(data_loader)
9294

9395
def _reconstruct_cartesian(
9496
self,
@@ -144,19 +146,19 @@ def _reconstruct_nufft(
144146
final_images = np.empty(
145147
(data_loader.n_frames, *data_loader.shape), dtype=np.float32
146148
)
149+
smaps = data_loader.get_smaps()
147150

148151
for i in tqdm(range(data_loader.n_frames)):
149152
traj, data = data_loader.get_kspace_frame(i)
150-
if data_loader.slice_2d:
151-
nufft_operator.samples = traj.reshape(
152-
data_loader.n_shots, -1, traj.shape[-1]
153-
)[0, :, :2]
154-
data = np.reshape(data, (data.shape[0], data_loader.n_shots, -1))
155-
for j in range(data.shape[1]):
156-
final_images[i, :, :, j] = abs(nufft_operator.adj_op(data[:, j]))
157-
else:
158-
nufft_operator.samples = traj
159-
final_images[i] = abs(nufft_operator.adj_op(data))
153+
#fix: density compensation should update every frame when rotating(writer: dyn_traj?)
154+
if smaps is not None and data_loader.slice_2d:
155+
nufft_operator.smaps = smaps[...,i % data_loader.frame.n_shots]
156+
nufft_operator.samples = traj
157+
final_images[i] = abs(nufft_operator.adj_op(data))
158+
if data_loader.slice_2d:
159+
final_images = np.moveaxis(final_images.reshape(
160+
(data_loader.frame.n_frames, -1, *final_images.shape[-2:])
161+
), 1, -1)
160162
return final_images
161163

162164

@@ -271,7 +273,7 @@ def reconstruct(self, data_loader: MRDLoader) -> np.ndarray:
271273
samples=traj,
272274
shape=data_loader.shape,
273275
n_coils=data_loader.n_coils,
274-
smaps=smaps,
276+
smaps=smaps.squeeze() if data_loader.slice_2d else smaps,
275277
# smaps=xp.array(smaps) if smaps is not None else None,
276278
density=density_compensation,
277279
squeeze_dims=True,
@@ -331,6 +333,10 @@ def reconstruct(self, data_loader: MRDLoader) -> np.ndarray:
331333

332334
pbar_frames.update(1)
333335
if self.restart_strategy != RestartStrategy.REFINE:
336+
if data_loader.slice_2d:
337+
final_estimate = np.moveaxis(final_estimate.reshape(
338+
(data_loader.frame.n_frames, -1, *final_estimate.shape[-2:])
339+
), 1, -1)
334340
return final_estimate
335341
# else, we do a second pass on the data using the last iteration as a slotion.
336342
pbar_frames.reset()
@@ -353,6 +359,10 @@ def reconstruct(self, data_loader: MRDLoader) -> np.ndarray:
353359
else:
354360
final_estimate[i, ...] = abs(x_iter)
355361
pbar_frames.update(1)
362+
if data_loader.slice_2d:
363+
final_estimate = np.moveaxis(final_estimate.reshape(
364+
(data_loader.frame.n_frames, -1, *final_estimate.shape[-2:])
365+
), 1, -1)
356366
return final_estimate
357367

358368
def _reconstruct_frame(
@@ -373,7 +383,7 @@ def _reconstruct_frame(
373383
grad_op=grad_op,
374384
linear_op=copy.deepcopy(self.space_linear_op),
375385
prox_op=copy.deepcopy(self.space_prox_op),
376-
x_init=x_init,
386+
x_init=x_init.get(),
377387
synthesis_init=False,
378388
metric_kwargs={},
379389
compute_backend=self.compute_backend,

0 commit comments

Comments
 (0)