@@ -138,6 +138,55 @@ def plot_lv(coords: np.ndarray, filepath: Path) -> None:
138138 plt .close (fig )
139139
140140
141+ def plot_heatmap_and_landmarks (images : np .ndarray , probs : np .ndarray , coords : np .ndarray , filepath : Path ) -> None :
142+ """Plot combined heatmap and landmarks as animated GIF.
143+
144+ Args:
145+ images: (x, y, 1, t)
146+ probs: (3, x, y, t)
147+ coords: (6, t)
148+ filepath: path to save the GIF file.
149+ """
150+ n_frames = probs .shape [- 1 ]
151+ frames = []
152+
153+ for t in tqdm (range (n_frames ), desc = "Creating combined GIF frames" ):
154+ # Create single frame with image + colored heatmaps + landmarks
155+ fig , ax = plt .subplots (figsize = (5 , 5 ), dpi = 150 )
156+
157+ # Plot original image as background
158+ ax .imshow (images [..., 0 , t ], cmap = "gray" )
159+
160+ # Plot heatmap
161+ ax .imshow (probs [0 , ..., t , None ] * np .array ([1 , 0 , 0 , 0.6 ]))
162+ ax .imshow (probs [1 , ..., t , None ] * np .array ([1 , 0 , 0 , 0.6 ]))
163+ ax .imshow (probs [2 , ..., t , None ] * np .array ([1 , 0 , 0 , 0.6 ]))
164+
165+ # Add landmark crosses on top
166+ for k in range (3 ):
167+ pred_x , pred_y = coords [2 * k , t ], coords [2 * k + 1 , t ]
168+ ax .plot ([pred_y - 9 , pred_y + 9 ], [pred_x , pred_x ], color = "red" , linewidth = 2 )
169+ ax .plot ([pred_y , pred_y ], [pred_x - 9 , pred_x + 9 ], color = "red" , linewidth = 2 )
170+
171+ ax .set_xticks ([])
172+ ax .set_yticks ([])
173+
174+ # Render figure to numpy array using BytesIO
175+ buf = io .BytesIO ()
176+ plt .savefig (buf , format = "png" , bbox_inches = "tight" , pad_inches = 0 , dpi = 150 )
177+ buf .seek (0 )
178+ img = Image .open (buf )
179+ frame = np .array (img .convert ("RGB" ))
180+ frames .append (frame )
181+ buf .close ()
182+ plt .close (fig )
183+
184+ # Create GIF directly from memory arrays
185+ with imageio .get_writer (filepath , mode = "I" , duration = 100 , loop = 0 ) as writer :
186+ for frame in tqdm (frames , desc = "Creating combined GIF" ):
187+ writer .append_data (frame )
188+
189+
141190def run (view : str , seed : int , device : torch .device , dtype : torch .dtype ) -> None :
142191 """Run landmark localization on LAX images using fine-tuned checkpoint."""
143192 # load model
@@ -180,6 +229,9 @@ def run(view: str, seed: int, device: torch.device, dtype: torch.dtype) -> None:
180229 # visualise LV length changes
181230 plot_lv (coords , Path (f"landmark_heatmap_gls_{ view } _{ seed } .png" ))
182231
232+ # visualise heatmap and landmarks
233+ plot_heatmap_and_landmarks (images , probs , coords , Path (f"landmark_heatmap_probs_and_landmark_{ view } _{ seed } .gif" ))
234+
183235
184236if __name__ == "__main__" :
185237 dtype , device = torch .float32 , torch .device ("cpu" )
0 commit comments