1111from cinema import CineMA , patchify , unpatchify
1212
1313
14+ def plot_mae_reconstruction (
15+ model : CineMA ,
16+ batch : dict [str , torch .Tensor ],
17+ pred_dict : dict [str , torch .Tensor ],
18+ enc_mask_dict : dict [str , torch .Tensor ],
19+ sax_slices : int ,
20+ ) -> plt .Figure :
21+ """Plot MAE reconstruction."""
22+ n_rows = sax_slices + 3
23+ n_cols = 4
24+ fig , axs = plt .subplots (n_rows , n_cols , figsize = (n_cols * 2 , n_rows * 2 ), dpi = 300 )
25+ for i , view in enumerate (["lax_2c" , "lax_3c" , "lax_4c" , "sax" ]):
26+ patches = patchify (image = batch [view ], patch_size = model .dec_patch_size_dict [view ])
27+ patches [enc_mask_dict [view ]] = pred_dict [view ]
28+ masks = torch .zeros_like (patches )
29+ masks [enc_mask_dict [view ]] = 1
30+ masks = unpatchify (
31+ masks , patch_size = model .dec_patch_size_dict [view ], grid_size = model .enc_down_dict [view ].patch_embed .grid_size
32+ )
33+ masks = masks [0 , 0 ]
34+ reconstructed = unpatchify (
35+ patches ,
36+ patch_size = model .dec_patch_size_dict [view ],
37+ grid_size = model .enc_down_dict [view ].patch_embed .grid_size ,
38+ )
39+ reconstructed = reconstructed [0 , 0 ].detach ().cpu ().numpy ()
40+ image = batch [view ][0 , 0 ].detach ().cpu ().numpy ()
41+ error = np .abs (reconstructed - image )
42+
43+ if view == "sax" :
44+ reconstructed = reconstructed [..., :sax_slices ]
45+ for j in range (sax_slices ):
46+ axs [3 + j , 0 ].set_ylabel (f"SAX slice { j } " )
47+ axs [3 + j , 0 ].imshow (image [..., j ], cmap = "gray" )
48+ axs [3 + j , 1 ].imshow (masks [..., j ], cmap = "gray" )
49+ axs [3 + j , 2 ].imshow (reconstructed [..., j ], cmap = "gray" )
50+ axs [3 + j , 3 ].imshow (error [..., j ], cmap = "gray" )
51+ else :
52+ axs [i , 0 ].imshow (image , cmap = "gray" )
53+ axs [i , 1 ].imshow (masks , cmap = "gray" )
54+ axs [i , 2 ].imshow (reconstructed , cmap = "gray" )
55+ axs [i , 3 ].imshow (error , cmap = "gray" )
56+ axs [i , 0 ].set_ylabel ({"lax_2c" : "LAX 2C" , "lax_3c" : "LAX 3C" , "lax_4c" : "LAX 4C" }[view ])
57+ if i == 0 :
58+ axs [i , 0 ].set_title ("Original" )
59+ axs [i , 1 ].set_title ("Mask" )
60+ axs [i , 2 ].set_title ("Reconstructed" )
61+ axs [i , 3 ].set_title ("Error" )
62+ # remove the x and y ticks
63+ for i in range (n_rows ):
64+ for j in range (n_cols ):
65+ axs [i , j ].set_xticks ([])
66+ axs [i , j ].set_yticks ([])
67+ fig .tight_layout ()
68+ fig .subplots_adjust (wspace = 0 , hspace = 0 )
69+ return fig
70+
71+
1472def run (device : torch .device , dtype : torch .dtype ) -> None :
1573 """Run MAE reconstruction."""
74+ t = 25 # which time frame to use
75+
1676 # load model
1777 model = CineMA .from_pretrained ()
18- model .to (device )
1978 model .eval ()
79+ model .to (device )
2080
2181 # load sample data and form a batch of size 1
2282 transform = Compose (
@@ -46,7 +106,7 @@ def run(device: torch.device, dtype: torch.dtype) -> None:
46106 lax_4c_image = torch .from_numpy (
47107 np .transpose (sitk .GetArrayFromImage (sitk .ReadImage (exp_dir / "data/ukb/1/1_lax_4c.nii.gz" )))
48108 )
49- t = 25 # which time frame to use
109+ sax_slices = sax_image . shape [ - 2 ]
50110 batch = {
51111 "sax" : sax_image [None , ..., t ],
52112 "lax_2c" : lax_2c_image [None , ..., 0 , t ],
@@ -62,45 +122,8 @@ def run(device: torch.device, dtype: torch.dtype) -> None:
62122 _ , pred_dict , enc_mask_dict , _ = model (batch , enc_mask_ratio = 0.75 )
63123
64124 # visualize
65- _ , axs = plt .subplots (6 , 4 , figsize = (12 , 18 ))
66- for i , view in enumerate (["lax_2c" , "lax_3c" , "lax_4c" , "sax" ]):
67- patches = patchify (image = batch [view ], patch_size = model .dec_patch_size_dict [view ])
68- patches [enc_mask_dict [view ]] = pred_dict [view ]
69- masks = torch .zeros_like (patches )
70- masks [enc_mask_dict [view ]] = 1
71- masks = unpatchify (
72- masks , patch_size = model .dec_patch_size_dict [view ], grid_size = model .enc_down_dict [view ].patch_embed .grid_size
73- )
74- masks = masks [0 , 0 ]
75- reconstructed = unpatchify (
76- patches ,
77- patch_size = model .dec_patch_size_dict [view ],
78- grid_size = model .enc_down_dict [view ].patch_embed .grid_size ,
79- )
80- reconstructed = reconstructed [0 , 0 ].detach ().cpu ().numpy ()
81- image = batch [view ][0 , 0 ].detach ().cpu ().numpy ()
82- error = np .abs (reconstructed - image )
83-
84- if view == "sax" :
85- for j in range (3 ):
86- z = j * 3
87- axs [3 + j , 0 ].set_ylabel (f"{ view } slice { z } " )
88- axs [3 + j , 0 ].imshow (image [..., z ], cmap = "gray" )
89- axs [3 + j , 1 ].imshow (masks [..., z ], cmap = "gray" )
90- axs [3 + j , 2 ].imshow (reconstructed [..., z ], cmap = "gray" )
91- axs [3 + j , 3 ].imshow (error [..., z ], cmap = "gray" )
92- else :
93- axs [i , 0 ].imshow (image , cmap = "gray" )
94- axs [i , 1 ].imshow (masks , cmap = "gray" )
95- axs [i , 2 ].imshow (reconstructed , cmap = "gray" )
96- axs [i , 3 ].imshow (error , cmap = "gray" )
97- axs [i , 0 ].set_ylabel (view )
98- if i == 0 :
99- axs [i , 0 ].set_title ("Original" )
100- axs [i , 1 ].set_title ("Mask" )
101- axs [i , 2 ].set_title ("Reconstructed" )
102- axs [i , 3 ].set_title ("Error" )
103- plt .savefig ("mae_reconstruction.png" , dpi = 300 , bbox_inches = "tight" )
125+ fig = plot_mae_reconstruction (model , batch , pred_dict , enc_mask_dict , sax_slices )
126+ fig .savefig ("mae_reconstruction.png" , dpi = 300 , bbox_inches = "tight" )
104127 plt .show (block = False )
105128
106129
0 commit comments