1414 CartesianFrameDataLoader ,
1515 MRDLoader ,
1616 NonCartesianFrameDataLoader ,
17+ SliceDataloader ,
1718)
1819from snake .core .parallel import (
1920 ArrayProps ,
@@ -83,12 +84,13 @@ def reconstruct(
8384 ) -> NDArray :
8485 """Reconstruct data with zero-filled method."""
8586 with data_loader :
86- if isinstance (data_loader , CartesianFrameDataLoader ):
87- return self ._reconstruct_cartesian ( data_loader )
88- elif isinstance (data_loader , NonCartesianFrameDataLoader ):
89- return self ._reconstruct_nufft (data_loader )
87+ if isinstance (data_loader , SliceDataloader ):
88+ reconstruct_method = self ._reconstruct_nufft if data_loader . is_non_cartesian else self . _reconstruct_cartesian
89+ elif isinstance (data_loader , CartesianFrameDataLoader | NonCartesianFrameDataLoader ):
90+ reconstruct_method = self ._reconstruct_cartesian if isinstance (data_loader , CartesianFrameDataLoader ) else self . _reconstruct_nufft
9091 else :
9192 raise ValueError ("Unknown dataloader" )
93+ return reconstruct_method (data_loader )
9294
9395 def _reconstruct_cartesian (
9496 self ,
@@ -144,19 +146,19 @@ def _reconstruct_nufft(
144146 final_images = np .empty (
145147 (data_loader .n_frames , * data_loader .shape ), dtype = np .float32
146148 )
149+ smaps = data_loader .get_smaps ()
147150
148151 for i in tqdm (range (data_loader .n_frames )):
149152 traj , data = data_loader .get_kspace_frame (i )
150- if data_loader .slice_2d :
151- nufft_operator .samples = traj .reshape (
152- data_loader .n_shots , - 1 , traj .shape [- 1 ]
153- )[0 , :, :2 ]
154- data = np .reshape (data , (data .shape [0 ], data_loader .n_shots , - 1 ))
155- for j in range (data .shape [1 ]):
156- final_images [i , :, :, j ] = abs (nufft_operator .adj_op (data [:, j ]))
157- else :
158- nufft_operator .samples = traj
159- final_images [i ] = abs (nufft_operator .adj_op (data ))
153+ #fix: density compensation should update every frame when rotating(writer: dyn_traj?)
154+ if smaps is not None and data_loader .slice_2d :
155+ nufft_operator .smaps = smaps [...,i % data_loader .frame .n_shots ]
156+ nufft_operator .samples = traj
157+ final_images [i ] = abs (nufft_operator .adj_op (data ))
158+ if data_loader .slice_2d :
159+ final_images = np .moveaxis (final_images .reshape (
160+ (data_loader .frame .n_frames , - 1 , * final_images .shape [- 2 :])
161+ ), 1 , - 1 )
160162 return final_images
161163
162164
@@ -271,7 +273,7 @@ def reconstruct(self, data_loader: MRDLoader) -> np.ndarray:
271273 samples = traj ,
272274 shape = data_loader .shape ,
273275 n_coils = data_loader .n_coils ,
274- smaps = smaps ,
276+ smaps = smaps . squeeze () if data_loader . slice_2d else smaps ,
275277 # smaps=xp.array(smaps) if smaps is not None else None,
276278 density = density_compensation ,
277279 squeeze_dims = True ,
@@ -331,6 +333,10 @@ def reconstruct(self, data_loader: MRDLoader) -> np.ndarray:
331333
332334 pbar_frames .update (1 )
333335 if self .restart_strategy != RestartStrategy .REFINE :
336+ if data_loader .slice_2d :
337+ final_estimate = np .moveaxis (final_estimate .reshape (
338+ (data_loader .frame .n_frames , - 1 , * final_estimate .shape [- 2 :])
339+ ), 1 , - 1 )
334340 return final_estimate
335341 # else, we do a second pass on the data using the last iteration as a slotion.
336342 pbar_frames .reset ()
@@ -353,6 +359,10 @@ def reconstruct(self, data_loader: MRDLoader) -> np.ndarray:
353359 else :
354360 final_estimate [i , ...] = abs (x_iter )
355361 pbar_frames .update (1 )
362+ if data_loader .slice_2d :
363+ final_estimate = np .moveaxis (final_estimate .reshape (
364+ (data_loader .frame .n_frames , - 1 , * final_estimate .shape [- 2 :])
365+ ), 1 , - 1 )
356366 return final_estimate
357367
358368 def _reconstruct_frame (
@@ -373,7 +383,7 @@ def _reconstruct_frame(
373383 grad_op = grad_op ,
374384 linear_op = copy .deepcopy (self .space_linear_op ),
375385 prox_op = copy .deepcopy (self .space_prox_op ),
376- x_init = x_init ,
386+ x_init = x_init . get () ,
377387 synthesis_init = False ,
378388 metric_kwargs = {},
379389 compute_backend = self .compute_backend ,
0 commit comments