Skip to content

Commit 06a0716

Browse files
committed
draft of 2d SliceDataloader
1 parent cbf7751 commit 06a0716

File tree

5 files changed

+119
-64
lines changed

5 files changed

+119
-64
lines changed

src/snake/mrd_utils/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
CartesianFrameDataLoader,
55
MRDLoader,
66
NonCartesianFrameDataLoader,
7+
SliceDataloader,
78
parse_sim_conf,
89
parse_waveform_information,
910
read_mrd_header,
@@ -16,6 +17,7 @@
1617
"MRDLoader",
1718
"CartesianFrameDataLoader",
1819
"NonCartesianFrameDataLoader",
20+
"SliceDataloader",
1921
"parse_sim_conf",
2022
"parse_waveform_information",
2123
"make_base_mrd",

src/snake/mrd_utils/loader.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -366,6 +366,7 @@ class NonCartesianFrameDataLoader(MRDLoader):
366366
... image = nufft.adj_op(kspace)
367367
"""
368368

369+
369370
def get_kspace_frame(
370371
self, idx: int, shot_dim: bool = False
371372
) -> tuple[NDArray[np.float32], NDArray[np.complex64]]:
@@ -402,7 +403,48 @@ def get_kspace_frame(
402403
)
403404
return traj, data
404405

406+
class SliceDataloader(MRDLoader):
407+
"""Load slice MRD files k-space frames iteratively."""
408+
409+
def __init__(self,
410+
frame_dl: MRDLoader,
411+
):
412+
super().__init__(
413+
filename=frame_dl._filename,
414+
dataset_name=frame_dl._dataset_name,
415+
writeable=frame_dl._writeable,
416+
swmr=frame_dl._swmr,
417+
)
418+
self.__class__ = type("SliceDataloader", (type(frame_dl), SliceDataloader), {})
419+
#replace by adding an attribute of if_cartesian, also need to change on pysap
420+
self.frame = frame_dl
421+
self.is_cartesian = isinstance(frame_dl, CartesianFrameDataLoader)
422+
self.get_kspace_frame = self._get_kspace_frame
423+
424+
def __getattr__(self, name: str) -> Any:
425+
return getattr(self.frame, name)
426+
427+
@property
428+
def n_frames(self) -> int:
429+
"""Number of frames."""
430+
return self.frame.n_acquisition
431+
432+
@property
433+
def shape(self) -> tuple[int, int]:
434+
"""Shape of the volume."""
435+
return self.frame.shape[:2]
436+
437+
def _get_kspace_frame(
438+
self, idx: int,
439+
) -> tuple[NDArray[np.float32], NDArray[np.complex64]]:
440+
"""Get the k-space frame."""
441+
n_acq_per_frame = self.frame.n_acquisition // self.frame.n_frames
442+
traj, data = self.frame.get_kspace_frame(idx // self.frame.n_shots)
443+
traj = traj.reshape(n_acq_per_frame, -1, 3)
444+
data = data.reshape(self.frame.n_coils, n_acq_per_frame, -1)
405445

446+
return traj[idx%self.frame.n_shots,:,:2], data[:,idx%self.frame.n_shots,:]
447+
406448
def parse_sim_conf(header: mrd.xsd.ismrmrdHeader) -> SimConfig:
407449
"""Parse the header to populate SimConfig from an MRD Header."""
408450
from ..core import GreConfig, HardwareConfig, SimConfig

src/snake/toolkit/cli/reconstruction.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -49,9 +49,13 @@ def reconstruction(cfg: DictConfig) -> None:
4949
elif engine == "NUFFT":
5050
DataLoader = NonCartesianFrameDataLoader
5151

52+
5253
# Reconstructor.setup(sim_conf) # initialize operators
5354
# array = Reconstructor.reconstruct(dataloader, sim_conf)
5455
with DataLoader(cfg.filename) as data_loader:
56+
phantom = data_loader.get_phantom()
57+
dyn_datas = data_loader.get_all_dynamic()
58+
#data_loader = SliceDataloader(data_loader) if cfg.engine.slice_2d else data_loader
5559
for name, rec in cfg.reconstructors.items():
5660
rec_str = str(rec) # FIXME Also use parameters of reconstructors
5761
data_rec_file = Path(f"data_rec_{rec_str}.npy")
@@ -63,9 +67,7 @@ def reconstruction(cfg: DictConfig) -> None:
6367
np.save(data_rec_file, rec_data)
6468
log.info(f"Saved to {data_rec_file.resolve()}")
6569

66-
phantom = data_loader.get_phantom()
6770
roi_mask = phantom.masks[phantom.labels == cfg.stats.roi_tissue_name]
68-
dyn_datas = data_loader.get_all_dynamic()
6971
waveform_name = f"activation-{cfg.stats.event_name}"
7072
good_d = None
7173
for d in dyn_datas:

src/snake/toolkit/reconstructors/fourier.py

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
from numpy.typing import NDArray
44
import scipy as sp
5-
from snake.mrd_utils.loader import CartesianFrameDataLoader, NonCartesianFrameDataLoader
5+
from snake.mrd_utils.loader import CartesianFrameDataLoader, NonCartesianFrameDataLoader, MRDLoader
66

77

88
def fft(image: NDArray, axis: int | tuple[int] = -1) -> NDArray:
@@ -53,15 +53,13 @@ def init_nufft(
5353
density_compensation: bool = False,
5454
):
5555
from mrinufft import get_operator
56-
57-
smaps = data_loader.get_smaps()
56+
if data_loader.n_coils > 1:
57+
smaps = data_loader.get_smaps().squeeze()
58+
else:
59+
smaps = None
5860
shape = data_loader.shape
5961
traj, _ = data_loader.get_kspace_frame(0)
6062

61-
if data_loader.slice_2d:
62-
shape = data_loader.shape[:2]
63-
traj = traj.reshape(data_loader.n_shots, -1, traj.shape[-1])[0, :, :2]
64-
6563
kwargs = dict(
6664
shape=shape,
6765
n_coils=data_loader.n_coils,

src/snake/toolkit/reconstructors/pysap.py

Lines changed: 66 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727

2828
from .base import BaseReconstructor
2929
from .fourier import ifft, init_nufft
30-
30+
from mrinufft.density.utils import get_density
3131

3232
def _reconstruct_cartesian_frame(
3333
filename: os.PathLike,
@@ -42,7 +42,7 @@ def _reconstruct_cartesian_frame(
4242
):
4343
mask, kspace = data_loader.get_kspace_frame(idx)
4444
if data_loader.slice_2d:
45-
axes = (-2,-1)
45+
axes = (-2, -1)
4646
else:
4747
axes = tuple(range(len(data_loader.shape), 0, -1))
4848
adj_data = ifft(kspace, axis=axes)
@@ -140,22 +140,37 @@ def _reconstruct_nufft(
140140
nufft_operator = init_nufft(
141141
data_loader, self.nufft_backend, self.density_compensation
142142
)
143+
144+
method = self.density_compensation
145+
146+
if isinstance(method, str):
147+
method = get_density(method)
148+
if not callable(method):
149+
raise ValueError(f"Unknown density method: {method}")
150+
143151
final_images = np.empty(
144152
(data_loader.n_frames, *data_loader.shape), dtype=np.float32
145153
)
154+
smaps = data_loader.get_smaps()
146155

147156
for i in tqdm(range(data_loader.n_frames)):
148157
traj, data = data_loader.get_kspace_frame(i)
149-
if data_loader.slice_2d:
150-
nufft_operator.samples = traj.reshape(
151-
data_loader.n_shots, -1, traj.shape[-1]
152-
)[0, :, :2]
153-
data = np.reshape(data, (data.shape[0], data_loader.n_shots, -1))
154-
for j in range(data.shape[1]):
155-
final_images[i, :, :, j] = abs(nufft_operator.adj_op(data[:, j]))
156-
else:
157-
nufft_operator.samples = traj
158-
final_images[i] = abs(nufft_operator.adj_op(data))
158+
159+
160+
# if self.density_compensation is True : #fix: it should update every frame when rotating
161+
# nufft_operator.density = method(
162+
# traj, data_loader.shape, backend=self.nufft_backend
163+
# )
164+
165+
if smaps is not None:
166+
nufft_operator.smaps = smaps[...,i % data_loader.frame.n_shots]
167+
168+
nufft_operator.samples = traj
169+
final_images[i] = abs(nufft_operator.adj_op(data))
170+
if data_loader.slice_2d:
171+
final_images = np.moveaxis(final_images.reshape(
172+
(data_loader.frame.n_frames, -1, *final_images.shape[-2:])
173+
), 1, -1)
159174
return final_images
160175

161176

@@ -212,7 +227,7 @@ def setup(self, sim_conf: SimConfig = None, shape: tuple[int] = None) -> None:
212227

213228
self.space_linear_op = WaveletTransform(
214229
self.wavelet,
215-
shape=shape, ##### shape[:2] if data_loader.slice_2d
230+
shape=shape,
216231
level=3,
217232
mode="zero",
218233
compute_backend=self.compute_backend,
@@ -233,11 +248,10 @@ def setup(self, sim_conf: SimConfig = None, shape: tuple[int] = None) -> None:
233248
linear=Identity(), weights=self.threshold
234249
)
235250

236-
def reconstruct(self, data_loader: MRDLoader) -> np.ndarray:
251+
def reconstruct(self, data_loader: NonCartesianFrameDataLoader) -> np.ndarray:
237252
"""Reconstruct with Sequential."""
238253
shape = data_loader.shape
239-
if data_loader.slice_2d:
240-
shape = shape[:2]
254+
241255
self.setup(shape=shape)
242256
from fmri.operators.gradient import (
243257
GradAnalysis,
@@ -250,13 +264,9 @@ def reconstruct(self, data_loader: MRDLoader) -> np.ndarray:
250264
xp, _ = get_backend(self.compute_backend)
251265

252266
traj, data = data_loader.get_kspace_frame(0)
253-
if data_loader.slice_2d:
254-
traj = np.reshape(
255-
traj,(data_loader.n_shots, -1, traj.shape[-1])
256-
)[0, :, :2]
257-
data = np.reshape(data, (data_loader.n_shots, -1))
258-
259-
smaps = data_loader.get_smaps()
267+
268+
269+
smaps = data_loader.get_smaps()
260270

261271
density_compensation = self.density_compensation
262272
if (
@@ -273,10 +283,10 @@ def reconstruct(self, data_loader: MRDLoader) -> np.ndarray:
273283

274284
fourier_op = get_operator(
275285
self.nufft_backend,
276-
samples=traj,
277-
shape=shape,##### shape[:2] if 2D
286+
samples=traj,
287+
shape=shape, ##### shape[:2] if 2D
278288
n_coils=data_loader.n_coils,
279-
smaps=smaps,
289+
smaps=smaps.squeeze(),
280290
# smaps=xp.array(smaps) if smaps is not None else None,
281291
density=density_compensation,
282292
**kwargs,
@@ -307,32 +317,33 @@ def reconstruct(self, data_loader: MRDLoader) -> np.ndarray:
307317
x_init = fourier_op.adj_op(xp.array(data * density_comp_vector, copy=False))
308318
else:
309319
if data_loader.slice_2d:
310-
for i in range (data.shape[0]):
311-
x_init[...,i] = fourier_op.adj_op(
312-
xp.array(data[i][None,...], copy=False)
313-
)
320+
data = np.reshape(data, (data_loader.n_shots, data_loader.n_coils, -1))
321+
for i in range(data.shape[0]):
322+
x_init[..., i] = fourier_op.adj_op(
323+
xp.array(data[i][None, ...], copy=False)
324+
)
314325

315326
pbar_frames = tqdm(total=data_loader.n_frames, position=0)
316327
pbar_iter = tqdm(total=self.max_iter_per_frame, position=1)
317-
x_iter = x_init.copy()
318-
for i, traj, data in data_loader.iter_frames():
319-
#grad_op.fourier_op.samples = traj
328+
x_iter = x_init.copy()
329+
for i, traj, data in data_loader.iter_frames(): #iter_slice
330+
# grad_op.fourier_op.samples = traj
320331
spec_rad = grad_op.fourier_op.get_lipschitz_cst(20)
321-
#grad_op._obs_data = xp.array(data)
332+
# grad_op._obs_data = xp.array(data)
322333
grad_op.spec_rad = spec_rad
323334
grad_op.inv_spec_rad = 1 / spec_rad
324335
# pseudo code
325-
if data_loader.slice_2d:
326-
traj = np.reshape(
327-
traj,(data_loader.n_shots, -1, traj.shape[-1])
328-
)[0, :, :2]
329-
data = np.reshape(data, ( data_loader.n_shots, -1))
336+
if data_loader.slice_2d:
337+
traj = np.reshape(traj, (data_loader.n_shots, -1, traj.shape[-1]))[
338+
0, :, :2
339+
]
340+
data = np.reshape(data, (data_loader.n_shots, data_loader.n_coils, -1))
330341
grad_op.fourier_op.samples = traj
331342
for j in range(data.shape[0]):
332-
grad_op._obs_data = xp.array(data[j][None,...])
333-
x_iter[...,j] = self._reconstruct_frame(
334-
grad_op, x_init[...,j], n_iter=self.max_iter_per_frame
335-
)
343+
grad_op._obs_data = xp.array(data[j][None, ...])
344+
x_iter[..., j] = self._reconstruct_frame(
345+
grad_op, x_init[..., j], n_iter=self.max_iter_per_frame
346+
)
336347
# Prepare for next iteration and save results
337348
x_init = (
338349
x_iter.copy()
@@ -349,26 +360,26 @@ def reconstruct(self, data_loader: MRDLoader) -> np.ndarray:
349360
return final_estimate
350361
# else, we do a second pass on the data using the last iteration as a solution.
351362
pbar_frames.reset()
352-
#pbar_iter.reset()
363+
# pbar_iter.reset()
353364
x_init = x_iter.copy() # last iteration results.
354365
for i, traj, data in data_loader.iter_frames():
355-
#grad_op.fourier_op.samples = traj
366+
# grad_op.fourier_op.samples = traj
356367
spec_rad = grad_op.fourier_op.get_lipschitz_cst()
357-
#grad_op._obs_data = xp.array(data)
368+
# grad_op._obs_data = xp.array(data)
358369
grad_op.spec_rad = spec_rad
359370
grad_op.inv_spec_rad = 1 / spec_rad
360371
# pseudo code
361-
if data_loader.slice_2d:
362-
traj = np.reshape(
363-
traj,(data_loader.n_shots, -1, traj.shape[-1])
364-
)[0, :, :2]
365-
data = np.reshape(data, (data_loader.n_shots, -1))
372+
if data_loader.slice_2d:
373+
traj = np.reshape(traj, (data_loader.n_shots, -1, traj.shape[-1]))[
374+
0, :, :2
375+
]
376+
data = np.reshape(data, (data_loader.n_shots, data_loader.n_coils, -1))
366377
grad_op.fourier_op.samples = traj
367378
for i in range(data.shape[0]):
368379
grad_op._obs_data = xp.array(data[i])
369-
x_iter[...,i] = self._reconstruct_frame(
370-
grad_op, x_init[:,:,i], n_iter=self.max_iter_per_frame
371-
)
380+
x_iter[..., i] = self._reconstruct_frame(
381+
grad_op, x_init[:, :, i], n_iter=self.max_iter_per_frame
382+
)
372383
if self.compute_backend == "cupy":
373384
final_estimate[i, ...] = abs(x_iter).get() # type: ignore
374385
else:
@@ -411,4 +422,4 @@ def _reconstruct_frame(
411422
img = grad_op.linear_op.adj_op(opt.x_final)
412423
else:
413424
img = opt.x_final
414-
return img
425+
return img

0 commit comments

Comments
 (0)