Skip to content

Commit 1346057

Browse files
committed
Store gif instead of grid subplot
1 parent af1958f commit 1346057

File tree

6 files changed

+216
-163
lines changed

6 files changed

+216
-163
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -151,3 +151,4 @@ wandb/
151151
# docs
152152
out/
153153
**/*.png
154+
**/*.gif

cinema/examples/inference/landmark_coordinate.py

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22

33
from pathlib import Path
44

5-
import matplotlib.pyplot as plt
65
import numpy as np
76
import SimpleITK as sitk # noqa: N813
87
import torch
@@ -43,14 +42,10 @@ def run(view: str, seed: int, device: torch.device, dtype: torch.dtype) -> None:
4342
coords = np.stack(coords_list, axis=-1) # (6, t)
4443

4544
# visualise landmarks
46-
fig = plot_landmarks(images, coords)
47-
fig.savefig(f"landmark_coordinate_landmark_{view}_{seed}.png", dpi=300, bbox_inches="tight")
48-
plt.show(block=False)
45+
plot_landmarks(images, coords, Path(f"landmark_coordinate_landmark_{view}_{seed}.gif"))
4946

5047
# visualise LV length changes
51-
fig = plot_lv(coords)
52-
plt.savefig(f"landmark_coordinate_gls_{view}_{seed}.png", dpi=300, bbox_inches="tight")
53-
plt.show(block=False)
48+
plot_lv(coords, Path(f"landmark_coordinate_gls_{view}_{seed}.png"))
5449

5550

5651
if __name__ == "__main__":

cinema/examples/inference/landmark_heatmap.py

Lines changed: 86 additions & 74 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
from pathlib import Path
44

5+
import imageio
56
import matplotlib.pyplot as plt
67
import numpy as np
78
import SimpleITK as sitk # noqa: N813
@@ -12,82 +13,98 @@
1213
from cinema import ConvUNetR, heatmap_soft_argmax
1314

1415

15-
def plot_heatmaps(images: np.ndarray, probs: np.ndarray, n_cols: int = 5) -> plt.Figure:
16-
"""Plot heatmaps.
16+
def plot_heatmaps(images: np.ndarray, probs: np.ndarray, filepath: Path) -> None:
17+
"""Plot heatmaps as animated GIF.
1718
1819
Args:
19-
images: (x, y, t)
20+
images: (x, y, 1, t)
2021
probs: (3, x, y, t)
21-
n_cols: number of columns
22-
23-
Returns:
24-
figure
22+
filepath: path to save the GIF file.
2523
"""
2624
n_frames = probs.shape[-1]
27-
n_rows = n_frames // n_cols
28-
fig, axs = plt.subplots(n_rows, n_cols, figsize=(n_cols, n_rows), dpi=300)
29-
for i in range(n_rows):
30-
for j in range(n_cols):
31-
t = i * n_cols + j
32-
axs[i, j].imshow(images[..., 0, t], cmap="gray")
33-
axs[i, j].imshow(probs[0, ..., t, None] * np.array([1.0, 0.0, 0.0, 1.0]))
34-
axs[i, j].imshow(probs[1, ..., t, None] * np.array([1.0, 0.0, 0.0, 1.0]))
35-
axs[i, j].imshow(probs[2, ..., t, None] * np.array([1.0, 0.0, 0.0, 1.0]))
36-
axs[i, j].set_xticks([])
37-
axs[i, j].set_yticks([])
38-
if j == 0:
39-
axs[i, j].set_ylabel(f"t = {t}")
40-
fig.tight_layout()
41-
fig.subplots_adjust(wspace=0, hspace=0)
42-
return fig
43-
44-
45-
def plot_landmarks(images: np.ndarray, coords: np.ndarray, n_cols: int = 5) -> plt.Figure:
46-
"""Plot landmarks.
25+
temp_frame_paths = []
26+
27+
for t in tqdm(range(n_frames), desc="Creating heatmap GIF frames"):
28+
# Create individual frame
29+
fig, ax = plt.subplots(figsize=(5, 5), dpi=300)
30+
31+
# Plot image
32+
ax.imshow(images[..., 0, t], cmap="gray")
33+
34+
# Plot heatmap overlays
35+
ax.imshow(probs[0, ..., t, None] * np.array([1.0, 0.0, 0.0, 1.0]))
36+
ax.imshow(probs[1, ..., t, None] * np.array([1.0, 0.0, 0.0, 1.0]))
37+
ax.imshow(probs[2, ..., t, None] * np.array([1.0, 0.0, 0.0, 1.0]))
38+
39+
# Remove axes
40+
ax.set_xticks([])
41+
ax.set_yticks([])
42+
43+
# Save frame
44+
frame_path = f"_tmp_heatmap_frame_{t:03d}.png"
45+
plt.savefig(frame_path, bbox_inches="tight", pad_inches=0, dpi=300)
46+
plt.close(fig)
47+
temp_frame_paths.append(frame_path)
48+
49+
# Create GIF
50+
with imageio.get_writer(filepath, mode="I", duration=100, loop=0) as writer:
51+
for frame_path in tqdm(temp_frame_paths, desc="Creating heatmap GIF"):
52+
image = imageio.v2.imread(frame_path)
53+
writer.append_data(image)
54+
# Clean up temporary file
55+
Path(frame_path).unlink()
56+
57+
58+
def plot_landmarks(images: np.ndarray, coords: np.ndarray, filepath: Path) -> None:
59+
"""Plot landmarks as animated GIF.
4760
4861
Args:
49-
images: (x, y, t)
62+
images: (x, y, 1, t)
5063
coords: (6, t)
51-
n_cols: number of columns
52-
53-
Returns:
54-
figure
64+
filepath: path to save the GIF file.
5565
"""
5666
n_frames = images.shape[-1]
57-
n_rows = n_frames // n_cols
58-
fig, axs = plt.subplots(n_rows, n_cols, figsize=(n_cols, n_rows), dpi=300)
59-
for i in range(n_rows):
60-
for j in range(n_cols):
61-
t = i * n_cols + j
62-
63-
# draw predictions with cross
64-
preds = images[..., t] * np.array([1, 1, 1])[None, None, :]
65-
preds = preds.clip(0, 255).astype(np.uint8)
66-
for k in range(3):
67-
pred_x, pred_y = coords[2 * k, t], coords[2 * k + 1, t]
68-
x1, x2 = max(0, pred_x - 9), min(preds.shape[0], pred_x + 10)
69-
y1, y2 = max(0, pred_y - 9), min(preds.shape[1], pred_y + 10)
70-
preds[pred_x, y1:y2] = [255, 0, 0]
71-
preds[x1:x2, pred_y] = [255, 0, 0]
72-
73-
axs[i, j].imshow(preds)
74-
axs[i, j].set_xticks([])
75-
axs[i, j].set_yticks([])
76-
if j == 0:
77-
axs[i, j].set_ylabel(f"t = {t}")
78-
fig.tight_layout()
79-
fig.subplots_adjust(wspace=0, hspace=0)
80-
return fig
81-
82-
83-
def plot_lv(coords: np.ndarray) -> plt.Figure:
67+
temp_frame_paths = []
68+
69+
for t in tqdm(range(n_frames), desc="Creating landmark GIF frames"):
70+
# Create individual frame
71+
fig, ax = plt.subplots(figsize=(5, 5), dpi=300)
72+
73+
# draw predictions with cross
74+
preds = images[..., t] * np.array([1, 1, 1])[None, None, :]
75+
preds = preds.clip(0, 255).astype(np.uint8)
76+
for k in range(3):
77+
pred_x, pred_y = coords[2 * k, t], coords[2 * k + 1, t]
78+
x1, x2 = max(0, pred_x - 9), min(preds.shape[0], pred_x + 10)
79+
y1, y2 = max(0, pred_y - 9), min(preds.shape[1], pred_y + 10)
80+
preds[pred_x, y1:y2] = [255, 0, 0]
81+
preds[x1:x2, pred_y] = [255, 0, 0]
82+
83+
ax.imshow(preds)
84+
ax.set_xticks([])
85+
ax.set_yticks([])
86+
87+
# Save frame
88+
frame_path = f"_tmp_landmark_frame_{t:03d}.png"
89+
plt.savefig(frame_path, bbox_inches="tight", pad_inches=0, dpi=300)
90+
plt.close(fig)
91+
temp_frame_paths.append(frame_path)
92+
93+
# Create GIF
94+
with imageio.get_writer(filepath, mode="I", duration=100, loop=0) as writer:
95+
for frame_path in tqdm(temp_frame_paths, desc="Creating landmark GIF"):
96+
image = imageio.v2.imread(frame_path)
97+
writer.append_data(image)
98+
# Clean up temporary file
99+
Path(frame_path).unlink()
100+
101+
102+
def plot_lv(coords: np.ndarray, filepath: Path) -> None:
84103
"""Plot GL shortening.
85104
86105
Args:
87106
coords: (6, t)
88-
89-
Returns:
90-
figure
107+
filepath: path to save the PNG file.
91108
"""
92109
# GL shortening
93110
x1, y1 = coords[0], coords[1]
@@ -108,12 +125,13 @@ def plot_lv(coords: np.ndarray) -> plt.Figure:
108125
) / 2
109126

110127
fig = plt.figure(figsize=(4, 4), dpi=120)
111-
plt.plot(lv_lengths, color="#82B366", label="LV")
128+
plt.plot(lv_lengths, color="#82B366", label="Left Ventricle")
112129
plt.xlabel("Frame")
113130
plt.ylabel("Length (mm)")
114-
plt.title(f"GLS = {gls:.2f}%, MAPSE = {mapse:.2f} mm")
131+
plt.title(f"GLS = {gls:.2f}%\nMAPSE = {mapse:.2f} mm")
115132
plt.legend(loc="lower right")
116-
return fig
133+
fig.savefig(filepath, dpi=300, bbox_inches="tight")
134+
plt.close(fig)
117135

118136

119137
def run(view: str, seed: int, device: torch.device, dtype: torch.dtype) -> None:
@@ -150,19 +168,13 @@ def run(view: str, seed: int, device: torch.device, dtype: torch.dtype) -> None:
150168
coords = np.stack(coords_list, axis=-1) # (6, t)
151169

152170
# visualise heatmaps
153-
fig = plot_heatmaps(images, probs)
154-
fig.savefig(f"landmark_heatmap_probs_{view}_{seed}.png", dpi=300, bbox_inches="tight")
155-
plt.show(block=False)
171+
plot_heatmaps(images, probs, Path(f"landmark_heatmap_probs_{view}_{seed}.gif"))
156172

157173
# visualise landmarks
158-
fig = plot_landmarks(images, coords)
159-
fig.savefig(f"landmark_heatmap_landmark_{view}_{seed}.png", dpi=300, bbox_inches="tight")
160-
plt.show(block=False)
174+
plot_landmarks(images, coords, Path(f"landmark_heatmap_landmark_{view}_{seed}.gif"))
161175

162176
# visualise LV length changes
163-
fig = plot_lv(coords)
164-
plt.savefig(f"landmark_heatmap_gls_{view}_{seed}.png", dpi=300, bbox_inches="tight")
165-
plt.show(block=False)
177+
plot_lv(coords, Path(f"landmark_heatmap_gls_{view}_{seed}.png"))
166178

167179

168180
if __name__ == "__main__":

cinema/examples/inference/mae.py

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,16 @@ def plot_mae_reconstruction(
1515
image_dict: dict[str, torch.Tensor],
1616
reconstructed_dict: dict[str, torch.Tensor],
1717
masks_dict: dict[str, torch.Tensor],
18-
) -> plt.Figure:
19-
"""Plot MAE reconstruction."""
18+
filepath: Path,
19+
) -> None:
20+
"""Plot MAE reconstruction.
21+
22+
Args:
23+
image_dict: Dictionary of original images
24+
reconstructed_dict: Dictionary of reconstructed images
25+
masks_dict: Dictionary of masks
26+
filepath: path to save the PNG file.
27+
"""
2028
sax_slices = image_dict["sax"].shape[-1]
2129
n_rows = sax_slices + 3
2230
n_cols = 4
@@ -52,7 +60,8 @@ def plot_mae_reconstruction(
5260
axs[i, j].set_yticks([])
5361
fig.tight_layout()
5462
fig.subplots_adjust(wspace=0, hspace=0)
55-
return fig
63+
fig.savefig(filepath, dpi=300, bbox_inches="tight")
64+
plt.close(fig)
5665

5766

5867
def reconstruct_images(
@@ -141,12 +150,12 @@ def run(device: torch.device, dtype: torch.dtype) -> None:
141150
batch["sax"] = batch["sax"][..., :sax_slices]
142151

143152
# visualize
144-
fig = plot_mae_reconstruction(
153+
plot_mae_reconstruction(
145154
batch,
146155
reconstructed_dict,
147156
masks_dict,
157+
Path("mae_reconstruction.png"),
148158
)
149-
fig.savefig("mae_reconstruction.png", dpi=300, bbox_inches="tight")
150159
plt.show(block=False)
151160

152161

0 commit comments

Comments
 (0)