@@ -138,6 +138,56 @@ def plot_lv(coords: np.ndarray, filepath: Path) -> None:
138138 plt .close (fig )
139139
140140
141+ def plot_heatmap_and_landmarks (images : np .ndarray , probs : np .ndarray , coords : np .ndarray , filepath : Path ) -> None :
142+ """Plot combined heatmap and landmarks as animated GIF.
143+
144+ Args:
145+ images: (x, y, 1, t)
146+ probs: (3, x, y, t)
147+ coords: (6, t)
148+ filepath: path to save the GIF file.
149+ """
150+ n_frames = probs .shape [- 1 ]
151+ frames = []
152+
153+ for t in tqdm (range (n_frames ), desc = "Creating combined GIF frames" ):
154+ # Create single frame with image + colored heatmaps + landmarks
155+ fig , ax = plt .subplots (figsize = (5 , 5 ), dpi = 150 )
156+
157+ # Plot original image as background
158+ ax .imshow (images [..., 0 , t ], cmap = "gray" )
159+
160+ # Plot heatmap
161+ ax .imshow (probs [0 , ..., t , None ] * np .array ([1 , 0 , 0 , 0.6 ]))
162+ ax .imshow (probs [1 , ..., t , None ] * np .array ([1 , 0 , 0 , 0.6 ]))
163+ ax .imshow (probs [2 , ..., t , None ] * np .array ([1 , 0 , 0 , 0.6 ]))
164+
165+ # Add landmark crosses on top
166+ for k in range (3 ):
167+ pred_x , pred_y = coords [2 * k , t ], coords [2 * k + 1 , t ]
168+ ax .plot ([pred_y - 9 , pred_y + 9 ], [pred_x , pred_x ], color = "red" , linewidth = 2 )
169+ ax .plot ([pred_y , pred_y ], [pred_x - 9 , pred_x + 9 ], color = "red" , linewidth = 2 )
170+
171+ ax .set_xticks ([])
172+ ax .set_yticks ([])
173+ ax .set_title (f"Time Frame { t } " )
174+
175+ # Render figure to numpy array using BytesIO
176+ buf = io .BytesIO ()
177+ plt .savefig (buf , format = "png" , bbox_inches = "tight" , pad_inches = 0 , dpi = 150 )
178+ buf .seek (0 )
179+ img = Image .open (buf )
180+ frame = np .array (img .convert ("RGB" ))
181+ frames .append (frame )
182+ buf .close ()
183+ plt .close (fig )
184+
185+ # Create GIF directly from memory arrays
186+ with imageio .get_writer (filepath , mode = "I" , duration = 100 , loop = 0 ) as writer :
187+ for frame in tqdm (frames , desc = "Creating combined GIF" ):
188+ writer .append_data (frame )
189+
190+
141191def run (view : str , seed : int , device : torch .device , dtype : torch .dtype ) -> None :
142192 """Run landmark localization on LAX images using fine-tuned checkpoint."""
143193 # load model
@@ -180,6 +230,9 @@ def run(view: str, seed: int, device: torch.device, dtype: torch.dtype) -> None:
180230 # visualise LV length changes
181231 plot_lv (coords , Path (f"landmark_heatmap_gls_{ view } _{ seed } .png" ))
182232
233+ # visualise heatmap and landmarks
234+ plot_heatmap_and_landmarks (images , probs , coords , Path (f"landmark_heatmap_probs_and_landmark_{ view } _{ seed } .gif" ))
235+
183236
184237if __name__ == "__main__" :
185238 dtype , device = torch .float32 , torch .device ("cpu" )
0 commit comments