1212
1313
1414def plot_mae_reconstruction (
15- batch : dict [str , torch .Tensor ],
16- pred_dict : dict [str , torch .Tensor ],
17- enc_mask_dict : dict [str , torch .Tensor ],
18- patch_size_dict : dict [str , tuple [int , ...]],
19- grid_size_dict : dict [str , tuple [int , ...]],
20- sax_slices : int ,
15+ image_dict : dict [str , torch .Tensor ],
16+ reconstructed_dict : dict [str , torch .Tensor ],
17+ masks_dict : dict [str , torch .Tensor ],
2118) -> plt .Figure :
2219 """Plot MAE reconstruction."""
20+ sax_slices = image_dict ["sax" ].shape [- 1 ]
2321 n_rows = sax_slices + 3
2422 n_cols = 4
2523 fig , axs = plt .subplots (n_rows , n_cols , figsize = (n_cols * 2 , n_rows * 2 ), dpi = 300 )
2624 for i , view in enumerate (["lax_2c" , "lax_3c" , "lax_4c" , "sax" ]):
27- patches = patchify (image = batch [view ], patch_size = patch_size_dict [view ])
28- patches [enc_mask_dict [view ]] = pred_dict [view ]
29- masks = torch .zeros_like (patches )
30- masks [enc_mask_dict [view ]] = 1
31- masks = unpatchify (masks , patch_size = patch_size_dict [view ], grid_size = grid_size_dict [view ])
32- masks = masks [0 , 0 ]
33- reconstructed = unpatchify (
34- patches ,
35- patch_size = patch_size_dict [view ],
36- grid_size = grid_size_dict [view ],
37- )
38- reconstructed = reconstructed [0 , 0 ].numpy ()
39- image = batch [view ][0 , 0 ].numpy ()
25+ masks = masks_dict [view ]
26+ reconstructed = reconstructed_dict [view ]
27+ image = image_dict [view ]
4028 error = np .abs (reconstructed - image )
4129
4230 if view == "sax" :
43- reconstructed = reconstructed [..., :sax_slices ]
4431 for j in range (sax_slices ):
4532 axs [3 + j , 0 ].set_ylabel (f"SAX slice { j } " )
4633 axs [3 + j , 0 ].imshow (image [..., j ], cmap = "gray" )
@@ -68,15 +55,42 @@ def plot_mae_reconstruction(
6855 return fig
6956
7057
58+ def reconstruct_images (
59+ batch : dict [str , torch .Tensor ],
60+ pred_dict : dict [str , torch .Tensor ],
61+ enc_mask_dict : dict [str , torch .Tensor ],
62+ patch_size_dict : dict [str , tuple [int , ...]],
63+ grid_size_dict : dict [str , tuple [int , ...]],
64+ sax_slices : int ,
65+ ) -> tuple [dict [str , np .ndarray ], dict [str , np .ndarray ]]:
66+ """Reconstruct images from predicted patches."""
67+ reconstructed_dict = {}
68+ masks_dict = {}
69+ for view in ["lax_2c" , "lax_3c" , "lax_4c" , "sax" ]:
70+ patches = patchify (image = batch [view ], patch_size = patch_size_dict [view ])
71+ patches [enc_mask_dict [view ]] = pred_dict [view ]
72+ masks = torch .zeros_like (patches )
73+ masks [enc_mask_dict [view ]] = 1
74+ masks = unpatchify (masks , patch_size = patch_size_dict [view ], grid_size = grid_size_dict [view ])
75+ reconstructed = unpatchify (
76+ patches ,
77+ patch_size = patch_size_dict [view ],
78+ grid_size = grid_size_dict [view ],
79+ )
80+ reconstructed_dict [view ] = reconstructed .detach ().cpu ().numpy ()[0 , 0 ]
81+ masks_dict [view ] = masks .detach ().cpu ().numpy ()[0 , 0 ]
82+ reconstructed_dict ["sax" ] = reconstructed_dict ["sax" ][..., :sax_slices ]
83+ masks_dict ["sax" ] = masks_dict ["sax" ][..., :sax_slices ]
84+ return reconstructed_dict , masks_dict
85+
86+
7187def run (device : torch .device , dtype : torch .dtype ) -> None :
7288 """Run MAE reconstruction."""
7389 t = 25 # which time frame to use
7490
7591 # load model
7692 model = CineMA .from_pretrained ()
7793 model .eval ()
78- patch_size_dict = model .dec_patch_size_dict
79- grid_size_dict = {k : v .patch_embed .grid_size for k , v in model .enc_down_dict .items ()}
8094 model .to (device )
8195
8296 # load sample data and form a batch of size 1
@@ -95,36 +109,41 @@ def run(device: torch.device, dtype: torch.dtype) -> None:
95109 )
96110 # (x, y, z, t) for SAX and (x, y, 1, t) for LAX
97111 exp_dir = Path (__file__ ).parent .parent .resolve ()
98- sax_image = torch .from_numpy (
99- np .transpose (sitk .GetArrayFromImage (sitk .ReadImage (exp_dir / "data/ukb/1/1_sax.nii.gz" )))
100- )
101- lax_2c_image = torch .from_numpy (
102- np .transpose (sitk .GetArrayFromImage (sitk .ReadImage (exp_dir / "data/ukb/1/1_lax_2c.nii.gz" )))
103- )
104- lax_3c_image = torch .from_numpy (
105- np .transpose (sitk .GetArrayFromImage (sitk .ReadImage (exp_dir / "data/ukb/1/1_lax_3c.nii.gz" )))
106- )
107- lax_4c_image = torch .from_numpy (
108- np .transpose (sitk .GetArrayFromImage (sitk .ReadImage (exp_dir / "data/ukb/1/1_lax_4c.nii.gz" )))
109- )
110- sax_slices = sax_image .shape [- 2 ]
111- batch = {
112- "sax" : sax_image [None , ..., t ],
113- "lax_2c" : lax_2c_image [None , ..., 0 , t ],
114- "lax_3c" : lax_3c_image [None , ..., 0 , t ],
115- "lax_4c" : lax_4c_image [None , ..., 0 , t ],
112+ sax_image = np .transpose (sitk .GetArrayFromImage (sitk .ReadImage (exp_dir / "data/ukb/1/1_sax.nii.gz" )))
113+ lax_2c_image = np .transpose (sitk .GetArrayFromImage (sitk .ReadImage (exp_dir / "data/ukb/1/1_lax_2c.nii.gz" )))
114+ lax_3c_image = np .transpose (sitk .GetArrayFromImage (sitk .ReadImage (exp_dir / "data/ukb/1/1_lax_3c.nii.gz" )))
115+ lax_4c_image = np .transpose (sitk .GetArrayFromImage (sitk .ReadImage (exp_dir / "data/ukb/1/1_lax_4c.nii.gz" )))
116+
117+ image_dict = {
118+ "sax" : sax_image [..., t ],
119+ "lax_2c" : lax_2c_image [..., 0 , t ],
120+ "lax_3c" : lax_3c_image [..., 0 , t ],
121+ "lax_4c" : lax_4c_image [..., 0 , t ],
116122 }
117- batch = transform (batch )
118- print (f"SAX view had originally { sax_image .shape [- 2 ]} slices, now zero-padded to { batch ['sax' ].shape [- 1 ]} slices." ) # noqa: T201
119- batch = {k : v [None , ...].to (device = device , dtype = dtype ) for k , v in batch .items ()}
123+ batch = {k : torch .from_numpy (v [None , ...]) for k , v in image_dict .items ()}
120124
121125 # forward
126+ sax_slices = batch ["sax" ].shape [- 1 ]
127+ batch = transform (batch )
128+ batch = {k : v [None , ...].to (device = device , dtype = dtype ) for k , v in batch .items ()}
122129 with torch .no_grad (), torch .autocast ("cuda" , dtype = dtype , enabled = torch .cuda .is_available ()):
123130 _ , pred_dict , enc_mask_dict , _ = model (batch , enc_mask_ratio = 0.75 )
131+ grid_size_dict = {k : v .patch_embed .grid_size for k , v in model .enc_down_dict .items ()}
132+ reconstructed_dict , masks_dict = reconstruct_images (
133+ batch ,
134+ pred_dict ,
135+ enc_mask_dict ,
136+ model .dec_patch_size_dict ,
137+ grid_size_dict ,
138+ sax_slices ,
139+ )
124140
125141 # visualize
126- batch = {k : v .detach ().cpu () for k , v in batch .items ()}
127- fig = plot_mae_reconstruction (batch , pred_dict , enc_mask_dict , patch_size_dict , grid_size_dict , sax_slices )
142+ fig = plot_mae_reconstruction (
143+ image_dict ,
144+ reconstructed_dict ,
145+ masks_dict ,
146+ )
128147 fig .savefig ("mae_reconstruction.png" , dpi = 300 , bbox_inches = "tight" )
129148 plt .show (block = False )
130149
0 commit comments