2727
2828from .base import BaseReconstructor
2929from .fourier import ifft , init_nufft
30-
30+ from mrinufft . density . utils import get_density
3131
3232def _reconstruct_cartesian_frame (
3333 filename : os .PathLike ,
@@ -42,7 +42,7 @@ def _reconstruct_cartesian_frame(
4242 ):
4343 mask , kspace = data_loader .get_kspace_frame (idx )
4444 if data_loader .slice_2d :
45- axes = (- 2 ,- 1 )
45+ axes = (- 2 , - 1 )
4646 else :
4747 axes = tuple (range (len (data_loader .shape ), 0 , - 1 ))
4848 adj_data = ifft (kspace , axis = axes )
@@ -140,22 +140,37 @@ def _reconstruct_nufft(
140140 nufft_operator = init_nufft (
141141 data_loader , self .nufft_backend , self .density_compensation
142142 )
143+
144+ method = self .density_compensation
145+
146+ if isinstance (method , str ):
147+ method = get_density (method )
148+ if not callable (method ):
149+ raise ValueError (f"Unknown density method: { method } " )
150+
143151 final_images = np .empty (
144152 (data_loader .n_frames , * data_loader .shape ), dtype = np .float32
145153 )
154+ smaps = data_loader .get_smaps ()
146155
147156 for i in tqdm (range (data_loader .n_frames )):
148157 traj , data = data_loader .get_kspace_frame (i )
149- if data_loader .slice_2d :
150- nufft_operator .samples = traj .reshape (
151- data_loader .n_shots , - 1 , traj .shape [- 1 ]
152- )[0 , :, :2 ]
153- data = np .reshape (data , (data .shape [0 ], data_loader .n_shots , - 1 ))
154- for j in range (data .shape [1 ]):
155- final_images [i , :, :, j ] = abs (nufft_operator .adj_op (data [:, j ]))
156- else :
157- nufft_operator .samples = traj
158- final_images [i ] = abs (nufft_operator .adj_op (data ))
158+
159+
160+ # if self.density_compensation is True : #fix: it should update every frame when rotating
161+ # nufft_operator.density = method(
162+ # traj, data_loader.shape, backend=self.nufft_backend
163+ # )
164+
165+ if smaps is not None :
166+ nufft_operator .smaps = smaps [...,i % data_loader .frame .n_shots ]
167+
168+ nufft_operator .samples = traj
169+ final_images [i ] = abs (nufft_operator .adj_op (data ))
170+ if data_loader .slice_2d :
171+ final_images = np .moveaxis (final_images .reshape (
172+ (data_loader .frame .n_frames , - 1 , * final_images .shape [- 2 :])
173+ ), 1 , - 1 )
159174 return final_images
160175
161176
@@ -212,7 +227,7 @@ def setup(self, sim_conf: SimConfig = None, shape: tuple[int] = None) -> None:
212227
213228 self .space_linear_op = WaveletTransform (
214229 self .wavelet ,
215- shape = shape , ##### shape[:2] if data_loader.slice_2d
230+ shape = shape ,
216231 level = 3 ,
217232 mode = "zero" ,
218233 compute_backend = self .compute_backend ,
@@ -233,11 +248,10 @@ def setup(self, sim_conf: SimConfig = None, shape: tuple[int] = None) -> None:
233248 linear = Identity (), weights = self .threshold
234249 )
235250
236- def reconstruct (self , data_loader : MRDLoader ) -> np .ndarray :
251+ def reconstruct (self , data_loader : NonCartesianFrameDataLoader ) -> np .ndarray :
237252 """Reconstruct with Sequential."""
238253 shape = data_loader .shape
239- if data_loader .slice_2d :
240- shape = shape [:2 ]
254+
241255 self .setup (shape = shape )
242256 from fmri .operators .gradient import (
243257 GradAnalysis ,
@@ -250,13 +264,9 @@ def reconstruct(self, data_loader: MRDLoader) -> np.ndarray:
250264 xp , _ = get_backend (self .compute_backend )
251265
252266 traj , data = data_loader .get_kspace_frame (0 )
253- if data_loader .slice_2d :
254- traj = np .reshape (
255- traj ,(data_loader .n_shots , - 1 , traj .shape [- 1 ])
256- )[0 , :, :2 ]
257- data = np .reshape (data , (data_loader .n_shots , - 1 ))
258-
259- smaps = data_loader .get_smaps ()
267+
268+
269+ smaps = data_loader .get_smaps ()
260270
261271 density_compensation = self .density_compensation
262272 if (
@@ -273,10 +283,10 @@ def reconstruct(self, data_loader: MRDLoader) -> np.ndarray:
273283
274284 fourier_op = get_operator (
275285 self .nufft_backend ,
276- samples = traj ,
277- shape = shape ,##### shape[:2] if 2D
286+ samples = traj ,
287+ shape = shape , ##### shape[:2] if 2D
278288 n_coils = data_loader .n_coils ,
279- smaps = smaps ,
289+ smaps = smaps . squeeze () ,
280290 # smaps=xp.array(smaps) if smaps is not None else None,
281291 density = density_compensation ,
282292 ** kwargs ,
@@ -307,32 +317,33 @@ def reconstruct(self, data_loader: MRDLoader) -> np.ndarray:
307317 x_init = fourier_op .adj_op (xp .array (data * density_comp_vector , copy = False ))
308318 else :
309319 if data_loader .slice_2d :
310- for i in range (data .shape [0 ]):
311- x_init [...,i ] = fourier_op .adj_op (
312- xp .array (data [i ][None ,...], copy = False )
313- )
320+ data = np .reshape (data , (data_loader .n_shots , data_loader .n_coils , - 1 ))
321+ for i in range (data .shape [0 ]):
322+ x_init [..., i ] = fourier_op .adj_op (
323+ xp .array (data [i ][None , ...], copy = False )
324+ )
314325
315326 pbar_frames = tqdm (total = data_loader .n_frames , position = 0 )
316327 pbar_iter = tqdm (total = self .max_iter_per_frame , position = 1 )
317- x_iter = x_init .copy ()
318- for i , traj , data in data_loader .iter_frames ():
319- #grad_op.fourier_op.samples = traj
328+ x_iter = x_init .copy ()
329+ for i , traj , data in data_loader .iter_frames (): #iter_slice
330+ # grad_op.fourier_op.samples = traj
320331 spec_rad = grad_op .fourier_op .get_lipschitz_cst (20 )
321- #grad_op._obs_data = xp.array(data)
332+ # grad_op._obs_data = xp.array(data)
322333 grad_op .spec_rad = spec_rad
323334 grad_op .inv_spec_rad = 1 / spec_rad
324335 # pseudo code
325- if data_loader .slice_2d :
326- traj = np .reshape (
327- traj ,( data_loader . n_shots , - 1 , traj . shape [ - 1 ])
328- )[ 0 , :, : 2 ]
329- data = np .reshape (data , ( data_loader .n_shots , - 1 ))
336+ if data_loader .slice_2d :
337+ traj = np .reshape (traj , ( data_loader . n_shots , - 1 , traj . shape [ - 1 ]))[
338+ 0 , :, : 2
339+ ]
340+ data = np .reshape (data , (data_loader .n_shots , data_loader . n_coils , - 1 ))
330341 grad_op .fourier_op .samples = traj
331342 for j in range (data .shape [0 ]):
332- grad_op ._obs_data = xp .array (data [j ][None ,...])
333- x_iter [...,j ] = self ._reconstruct_frame (
334- grad_op , x_init [...,j ], n_iter = self .max_iter_per_frame
335- )
343+ grad_op ._obs_data = xp .array (data [j ][None , ...])
344+ x_iter [..., j ] = self ._reconstruct_frame (
345+ grad_op , x_init [..., j ], n_iter = self .max_iter_per_frame
346+ )
336347 # Prepare for next iteration and save results
337348 x_init = (
338349 x_iter .copy ()
@@ -349,26 +360,26 @@ def reconstruct(self, data_loader: MRDLoader) -> np.ndarray:
349360 return final_estimate
350361 # else, we do a second pass on the data using the last iteration as a solution.
351362 pbar_frames .reset ()
352- #pbar_iter.reset()
363+ # pbar_iter.reset()
353364 x_init = x_iter .copy () # last iteration results.
354365 for i , traj , data in data_loader .iter_frames ():
355- #grad_op.fourier_op.samples = traj
366+ # grad_op.fourier_op.samples = traj
356367 spec_rad = grad_op .fourier_op .get_lipschitz_cst ()
357- #grad_op._obs_data = xp.array(data)
368+ # grad_op._obs_data = xp.array(data)
358369 grad_op .spec_rad = spec_rad
359370 grad_op .inv_spec_rad = 1 / spec_rad
360371 # pseudo code
361- if data_loader .slice_2d :
362- traj = np .reshape (
363- traj ,( data_loader . n_shots , - 1 , traj . shape [ - 1 ])
364- )[ 0 , :, : 2 ]
365- data = np .reshape (data , (data_loader .n_shots , - 1 ))
372+ if data_loader .slice_2d :
373+ traj = np .reshape (traj , ( data_loader . n_shots , - 1 , traj . shape [ - 1 ]))[
374+ 0 , :, : 2
375+ ]
376+ data = np .reshape (data , (data_loader .n_shots , data_loader . n_coils , - 1 ))
366377 grad_op .fourier_op .samples = traj
367378 for i in range (data .shape [0 ]):
368379 grad_op ._obs_data = xp .array (data [i ])
369- x_iter [...,i ] = self ._reconstruct_frame (
370- grad_op , x_init [:,:, i ], n_iter = self .max_iter_per_frame
371- )
380+ x_iter [..., i ] = self ._reconstruct_frame (
381+ grad_op , x_init [:, :, i ], n_iter = self .max_iter_per_frame
382+ )
372383 if self .compute_backend == "cupy" :
373384 final_estimate [i , ...] = abs (x_iter ).get () # type: ignore
374385 else :
@@ -411,4 +422,4 @@ def _reconstruct_frame(
411422 img = grad_op .linear_op .adj_op (opt .x_final )
412423 else :
413424 img = opt .x_final
414- return img
425+ return img
0 commit comments