@@ -73,18 +73,20 @@ def setup(self, sim_conf: SimConfig) -> None:
7373 """Initialize Reconstructor."""
7474 pass
7575
76- def reconstruct (self , data_loader : MRDLoader , sim_conf : SimConfig ) -> NDArray :
76+ def reconstruct (
77+ self , data_loader : MRDLoader , sim_conf : SimConfig , slice_2d : bool
78+ ) -> NDArray :
7779 """Reconstruct data with zero-filled method."""
7880 with data_loader :
7981 if isinstance (data_loader , CartesianFrameDataLoader ):
80- return self ._reconstruct_cartesian (data_loader , sim_conf )
82+ return self ._reconstruct_cartesian (data_loader , sim_conf , slice_2d )
8183 elif isinstance (data_loader , NonCartesianFrameDataLoader ):
82- return self ._reconstruct_nufft (data_loader , sim_conf )
84+ return self ._reconstruct_nufft (data_loader , sim_conf , slice_2d )
8385 else :
8486 raise ValueError ("Unknown dataloader" )
8587
8688 def _reconstruct_cartesian (
87- self , data_loader : CartesianFrameDataLoader , sim_conf : SimConfig
89+ self , data_loader : CartesianFrameDataLoader , sim_conf : SimConfig , slice_2d
8890 ) -> NDArray :
8991 smaps = data_loader .get_smaps ()
9092 if smaps is None and data_loader .n_coils > 1 :
@@ -114,7 +116,6 @@ def _reconstruct_cartesian(
114116 ): idx
115117 for idx in range (data_loader .n_frames )
116118 }
117-
118119 for future in as_completed (futures ):
119120 future .result ()
120121 pbar .update (1 )
@@ -126,16 +127,21 @@ def _reconstruct_cartesian(
126127 return final_images
127128
128129 def _reconstruct_nufft (
129- self , data_loader : NonCartesianFrameDataLoader , sim_conf : SimConfig
130+ self , data_loader : NonCartesianFrameDataLoader , sim_conf : SimConfig , slice_2d
130131 ) -> NDArray :
131132 """Reconstruct data with nufft method."""
132133 from mrinufft import get_operator
133134
134135 smaps = data_loader .get_smaps ()
135-
136+ shape = data_loader . shape
136137 traj , kspace_data = data_loader .get_kspace_frame (0 )
138+
139+ if slice_2d :
140+ shape = data_loader .shape [:2 ]
141+ traj = traj .reshape (data_loader .n_shots , - 1 , traj .shape [- 1 ])[0 , :, :2 ]
142+
137143 kwargs = dict (
138- shape = data_loader . shape ,
144+ shape = shape ,
139145 n_coils = data_loader .n_coils ,
140146 smaps = smaps ,
141147 )
@@ -146,6 +152,7 @@ def _reconstruct_nufft(
146152 kwargs ["density" ] = self .density_compensation
147153 if "stacked" in self .nufft_backend :
148154 kwargs ["z_index" ] = "auto"
155+
149156 nufft_operator = get_operator (
150157 self .nufft_backend ,
151158 samples = traj ,
@@ -158,8 +165,16 @@ def _reconstruct_nufft(
158165
159166 for i in tqdm (range (data_loader .n_frames )):
160167 traj , data = data_loader .get_kspace_frame (i )
161- nufft_operator .samples = traj
162- final_images [i ] = abs (nufft_operator .adj_op (data ))
168+ if slice_2d :
169+ nufft_operator .samples = traj .reshape (
170+ data_loader .n_shots , - 1 , traj .shape [- 1 ]
171+ )[0 , :, :2 ]
172+ data = np .reshape (data , (data .shape [0 ], data_loader .n_shots , - 1 ))
173+ for j in range (data .shape [1 ]):
174+ final_images [i , :, :, j ] = abs (nufft_operator .adj_op (data [:, j ]))
175+ else :
176+ nufft_operator .samples = traj
177+ final_images [i ] = abs (nufft_operator .adj_op (data ))
163178 return final_images
164179
165180
0 commit comments