Skip to content

Commit dd9c19c

Browse files
committed
Add visualisations for landmark
1 parent 0d1afe8 commit dd9c19c

File tree

1 file changed

+52
-0
lines changed

1 file changed

+52
-0
lines changed

cinema/examples/inference/landmark_heatmap.py

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -138,6 +138,55 @@ def plot_lv(coords: np.ndarray, filepath: Path) -> None:
138138
plt.close(fig)
139139

140140

141+
def plot_heatmap_and_landmarks(images: np.ndarray, probs: np.ndarray, coords: np.ndarray, filepath: Path) -> None:
142+
"""Plot combined heatmap and landmarks as animated GIF.
143+
144+
Args:
145+
images: (x, y, 1, t)
146+
probs: (3, x, y, t)
147+
coords: (6, t)
148+
filepath: path to save the GIF file.
149+
"""
150+
n_frames = probs.shape[-1]
151+
frames = []
152+
153+
for t in tqdm(range(n_frames), desc="Creating combined GIF frames"):
154+
# Create single frame with image + colored heatmaps + landmarks
155+
fig, ax = plt.subplots(figsize=(5, 5), dpi=150)
156+
157+
# Plot original image as background
158+
ax.imshow(images[..., 0, t], cmap="gray")
159+
160+
# Plot heatmap
161+
ax.imshow(probs[0, ..., t, None] * np.array([1, 0, 0, 0.6]))
162+
ax.imshow(probs[1, ..., t, None] * np.array([1, 0, 0, 0.6]))
163+
ax.imshow(probs[2, ..., t, None] * np.array([1, 0, 0, 0.6]))
164+
165+
# Add landmark crosses on top
166+
for k in range(3):
167+
pred_x, pred_y = coords[2 * k, t], coords[2 * k + 1, t]
168+
ax.plot([pred_y - 9, pred_y + 9], [pred_x, pred_x], color="red", linewidth=2)
169+
ax.plot([pred_y, pred_y], [pred_x - 9, pred_x + 9], color="red", linewidth=2)
170+
171+
ax.set_xticks([])
172+
ax.set_yticks([])
173+
174+
# Render figure to numpy array using BytesIO
175+
buf = io.BytesIO()
176+
plt.savefig(buf, format="png", bbox_inches="tight", pad_inches=0, dpi=150)
177+
buf.seek(0)
178+
img = Image.open(buf)
179+
frame = np.array(img.convert("RGB"))
180+
frames.append(frame)
181+
buf.close()
182+
plt.close(fig)
183+
184+
# Create GIF directly from memory arrays
185+
with imageio.get_writer(filepath, mode="I", duration=100, loop=0) as writer:
186+
for frame in tqdm(frames, desc="Creating combined GIF"):
187+
writer.append_data(frame)
188+
189+
141190
def run(view: str, seed: int, device: torch.device, dtype: torch.dtype) -> None:
142191
"""Run landmark localization on LAX images using fine-tuned checkpoint."""
143192
# load model
@@ -180,6 +229,9 @@ def run(view: str, seed: int, device: torch.device, dtype: torch.dtype) -> None:
180229
# visualise LV length changes
181230
plot_lv(coords, Path(f"landmark_heatmap_gls_{view}_{seed}.png"))
182231

232+
# visualise heatmap and landmarks
233+
plot_heatmap_and_landmarks(images, probs, coords, Path(f"landmark_heatmap_probs_and_landmark_{view}_{seed}.gif"))
234+
183235

184236
if __name__ == "__main__":
185237
dtype, device = torch.float32, torch.device("cpu")

0 commit comments

Comments
 (0)