@@ -77,8 +77,8 @@ def reconstruct_images(
7777 patch_size = patch_size_dict [view ],
7878 grid_size = grid_size_dict [view ],
7979 )
80- reconstructed_dict [view ] = reconstructed .detach ().cpu ().numpy ()[0 , 0 ]
81- masks_dict [view ] = masks .detach ().cpu ().numpy ()[0 , 0 ]
80+ reconstructed_dict [view ] = reconstructed .detach ().to ( torch . float32 ). cpu ().numpy ()[0 , 0 ]
81+ masks_dict [view ] = masks .detach ().to ( torch . float32 ). cpu ().numpy ()[0 , 0 ]
8282 reconstructed_dict ["sax" ] = reconstructed_dict ["sax" ][..., :sax_slices ]
8383 masks_dict ["sax" ] = masks_dict ["sax" ][..., :sax_slices ]
8484 return reconstructed_dict , masks_dict
@@ -115,12 +115,12 @@ def run(device: torch.device, dtype: torch.dtype) -> None:
115115 lax_4c_image = np .transpose (sitk .GetArrayFromImage (sitk .ReadImage (exp_dir / "data/ukb/1/1_lax_4c.nii.gz" )))
116116
117117 image_dict = {
118- "sax" : sax_image [..., t ],
119- "lax_2c" : lax_2c_image [..., 0 , t ],
120- "lax_3c" : lax_3c_image [..., 0 , t ],
121- "lax_4c" : lax_4c_image [..., 0 , t ],
118+ "sax" : sax_image [None , ..., t ],
119+ "lax_2c" : lax_2c_image [None , ..., 0 , t ],
120+ "lax_3c" : lax_3c_image [None , ..., 0 , t ],
121+ "lax_4c" : lax_4c_image [None , ..., 0 , t ],
122122 }
123- batch = {k : torch .from_numpy (v [ None , ...] ) for k , v in image_dict .items ()}
123+ batch = {k : torch .from_numpy (v ) for k , v in image_dict .items ()}
124124
125125 # forward
126126 sax_slices = batch ["sax" ].shape [- 1 ]
@@ -137,10 +137,12 @@ def run(device: torch.device, dtype: torch.dtype) -> None:
137137 grid_size_dict ,
138138 sax_slices ,
139139 )
140+ batch = {k : v .detach ().to (torch .float32 ).cpu ().numpy ()[0 , 0 ] for k , v in batch .items ()}
141+ batch ["sax" ] = batch ["sax" ][..., :sax_slices ]
140142
141143 # visualize
142144 fig = plot_mae_reconstruction (
143- image_dict ,
145+ batch ,
144146 reconstructed_dict ,
145147 masks_dict ,
146148 )
0 commit comments