Skip to content

Commit 7e86ffc

Browse files
committed
Accelerate gif making
1 parent 1346057 commit 7e86ffc

File tree

3 files changed

+60
-50
lines changed

3 files changed

+60
-50
lines changed

cinema/examples/inference/landmark_heatmap.py

Lines changed: 28 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
"""Example script to perform landmark localization on LAX images using fine-tuned checkpoint."""
22

3+
import io
34
from pathlib import Path
45

56
import imageio
@@ -8,6 +9,7 @@
89
import SimpleITK as sitk # noqa: N813
910
import torch
1011
from monai.transforms import ScaleIntensityd
12+
from PIL import Image
1113
from tqdm import tqdm
1214

1315
from cinema import ConvUNetR, heatmap_soft_argmax
@@ -22,11 +24,11 @@ def plot_heatmaps(images: np.ndarray, probs: np.ndarray, filepath: Path) -> None
2224
filepath: path to save the GIF file.
2325
"""
2426
n_frames = probs.shape[-1]
25-
temp_frame_paths = []
27+
frames = []
2628

2729
for t in tqdm(range(n_frames), desc="Creating heatmap GIF frames"):
2830
# Create individual frame
29-
fig, ax = plt.subplots(figsize=(5, 5), dpi=300)
31+
fig, ax = plt.subplots(figsize=(5, 5), dpi=150)
3032

3133
# Plot image
3234
ax.imshow(images[..., 0, t], cmap="gray")
@@ -40,19 +42,20 @@ def plot_heatmaps(images: np.ndarray, probs: np.ndarray, filepath: Path) -> None
4042
ax.set_xticks([])
4143
ax.set_yticks([])
4244

43-
# Save frame
44-
frame_path = f"_tmp_heatmap_frame_{t:03d}.png"
45-
plt.savefig(frame_path, bbox_inches="tight", pad_inches=0, dpi=300)
45+
# Render figure to numpy array using BytesIO (universal across backends)
46+
buf = io.BytesIO()
47+
plt.savefig(buf, format="png", bbox_inches="tight", pad_inches=0, dpi=150)
48+
buf.seek(0)
49+
img = Image.open(buf)
50+
frame = np.array(img.convert("RGB"))
51+
frames.append(frame)
52+
buf.close()
4653
plt.close(fig)
47-
temp_frame_paths.append(frame_path)
4854

49-
# Create GIF
55+
# Create GIF directly from memory arrays
5056
with imageio.get_writer(filepath, mode="I", duration=100, loop=0) as writer:
51-
for frame_path in tqdm(temp_frame_paths, desc="Creating heatmap GIF"):
52-
image = imageio.v2.imread(frame_path)
53-
writer.append_data(image)
54-
# Clean up temporary file
55-
Path(frame_path).unlink()
57+
for frame in tqdm(frames, desc="Creating heatmap GIF"):
58+
writer.append_data(frame)
5659

5760

5861
def plot_landmarks(images: np.ndarray, coords: np.ndarray, filepath: Path) -> None:
@@ -64,11 +67,11 @@ def plot_landmarks(images: np.ndarray, coords: np.ndarray, filepath: Path) -> No
6467
filepath: path to save the GIF file.
6568
"""
6669
n_frames = images.shape[-1]
67-
temp_frame_paths = []
70+
frames = []
6871

6972
for t in tqdm(range(n_frames), desc="Creating landmark GIF frames"):
7073
# Create individual frame
71-
fig, ax = plt.subplots(figsize=(5, 5), dpi=300)
74+
fig, ax = plt.subplots(figsize=(5, 5), dpi=150)
7275

7376
# draw predictions with cross
7477
preds = images[..., t] * np.array([1, 1, 1])[None, None, :]
@@ -84,19 +87,20 @@ def plot_landmarks(images: np.ndarray, coords: np.ndarray, filepath: Path) -> No
8487
ax.set_xticks([])
8588
ax.set_yticks([])
8689

87-
# Save frame
88-
frame_path = f"_tmp_landmark_frame_{t:03d}.png"
89-
plt.savefig(frame_path, bbox_inches="tight", pad_inches=0, dpi=300)
90+
# Render figure to numpy array using BytesIO (universal across backends)
91+
buf = io.BytesIO()
92+
plt.savefig(buf, format="png", bbox_inches="tight", pad_inches=0, dpi=150)
93+
buf.seek(0)
94+
img = Image.open(buf)
95+
frame = np.array(img.convert("RGB"))
96+
frames.append(frame)
97+
buf.close()
9098
plt.close(fig)
91-
temp_frame_paths.append(frame_path)
9299

93-
# Create GIF
100+
# Create GIF directly from memory arrays
94101
with imageio.get_writer(filepath, mode="I", duration=100, loop=0) as writer:
95-
for frame_path in tqdm(temp_frame_paths, desc="Creating landmark GIF"):
96-
image = imageio.v2.imread(frame_path)
97-
writer.append_data(image)
98-
# Clean up temporary file
99-
Path(frame_path).unlink()
102+
for frame in tqdm(frames, desc="Creating landmark GIF"):
103+
writer.append_data(frame)
100104

101105

102106
def plot_lv(coords: np.ndarray, filepath: Path) -> None:

cinema/examples/inference/segmentation_lax_4c.py

Lines changed: 15 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
"""Example script to perform segmentation on LAX 4C images using fine-tuned checkpoint."""
22

3+
import io
34
from pathlib import Path
45

56
import imageio
@@ -8,6 +9,7 @@
89
import SimpleITK as sitk # noqa: N813
910
import torch
1011
from monai.transforms import ScaleIntensityd
12+
from PIL import Image
1113
from scipy.spatial.distance import cdist
1214
from skimage import measure
1315
from tqdm import tqdm
@@ -51,11 +53,11 @@ def plot_segmentations(images: np.ndarray, labels: np.ndarray, filepath: Path) -
5153
filepath: path to save the GIF file.
5254
"""
5355
n_frames = labels.shape[-1]
54-
temp_frame_paths = []
56+
frames = []
5557

5658
for t in tqdm(range(n_frames), desc="Creating GIF frames"):
5759
# Create individual frame
58-
fig, ax = plt.subplots(figsize=(5, 5), dpi=300)
60+
fig, ax = plt.subplots(figsize=(5, 5), dpi=150)
5961

6062
# Plot image
6163
ax.imshow(images[..., 0, t], cmap="gray")
@@ -69,19 +71,20 @@ def plot_segmentations(images: np.ndarray, labels: np.ndarray, filepath: Path) -
6971
ax.set_xticks([])
7072
ax.set_yticks([])
7173

72-
# Save frame
73-
frame_path = f"_tmp_frame_{t:03d}.png"
74-
plt.savefig(frame_path, bbox_inches="tight", pad_inches=0, dpi=300)
74+
# Render figure to numpy array using BytesIO (universal across backends)
75+
buf = io.BytesIO()
76+
plt.savefig(buf, format="png", bbox_inches="tight", pad_inches=0, dpi=150)
77+
buf.seek(0)
78+
img = Image.open(buf)
79+
frame = np.array(img.convert("RGB"))
80+
frames.append(frame)
81+
buf.close()
7582
plt.close(fig)
76-
temp_frame_paths.append(frame_path)
7783

78-
# Create GIF
84+
# Create GIF directly from memory arrays
7985
with imageio.get_writer(filepath, mode="I", duration=100, loop=0) as writer:
80-
for frame_path in tqdm(temp_frame_paths, desc="Creating GIF"):
81-
image = imageio.v2.imread(frame_path)
82-
writer.append_data(image)
83-
# Clean up temporary file
84-
Path(frame_path).unlink()
86+
for frame in tqdm(frames, desc="Creating GIF"):
87+
writer.append_data(frame)
8588

8689

8790
def plot_volume_changes(labels: np.ndarray, filepath: Path) -> None:

cinema/examples/inference/segmentation_sax.py

Lines changed: 17 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
"""Example script to perform segmentation on SAX images using fine-tuned checkpoint."""
22

3+
import io
34
from pathlib import Path
45

56
import imageio
@@ -8,6 +9,7 @@
89
import SimpleITK as sitk # noqa: N813
910
import torch
1011
from monai.transforms import Compose, ScaleIntensityd, SpatialPadd
12+
from PIL import Image
1113
from tqdm import tqdm
1214

1315
from cinema import ConvUNetR
@@ -23,12 +25,12 @@ def plot_segmentations(images: np.ndarray, labels: np.ndarray, filepath: Path) -
2325
"""
2426
n_slices, n_frames = labels.shape[-2:]
2527
n_cols = 3
26-
n_rows = (n_slices + n_cols - 1) // n_cols # Calculate rows needed for 5 columns
27-
temp_frame_paths = []
28+
n_rows = (n_slices + n_cols - 1) // n_cols # Calculate rows needed for 3 columns
29+
frames = []
2830

2931
for t in tqdm(range(n_frames), desc="Creating segmentation GIF frames"):
30-
# Create individual frame with SAX slices in grid layout (5 columns)
31-
fig, axs = plt.subplots(n_rows, n_cols, figsize=(n_cols * 2, n_rows * 2), dpi=300)
32+
# Create individual frame with SAX slices in grid layout (3 columns)
33+
fig, axs = plt.subplots(n_rows, n_cols, figsize=(n_cols * 2, n_rows * 2), dpi=150)
3234

3335
# Handle different subplot arrangements
3436
if n_rows == 1 and n_cols == 1:
@@ -59,19 +61,20 @@ def plot_segmentations(images: np.ndarray, labels: np.ndarray, filepath: Path) -
5961
fig.tight_layout()
6062
fig.subplots_adjust(wspace=0.0, hspace=0.0)
6163

62-
# Save frame
63-
frame_path = f"_tmp_sax_frame_{t:03d}.png"
64-
plt.savefig(frame_path, bbox_inches="tight", pad_inches=0, dpi=300)
64+
# Render figure to numpy array using BytesIO (universal across backends)
65+
buf = io.BytesIO()
66+
plt.savefig(buf, format="png", bbox_inches="tight", pad_inches=0, dpi=150)
67+
buf.seek(0)
68+
img = Image.open(buf)
69+
frame = np.array(img.convert("RGB"))
70+
frames.append(frame)
71+
buf.close()
6572
plt.close(fig)
66-
temp_frame_paths.append(frame_path)
6773

68-
# Create GIF
74+
# Create GIF directly from memory arrays
6975
with imageio.get_writer(filepath, mode="I", duration=200, loop=0) as writer:
70-
for frame_path in tqdm(temp_frame_paths, desc="Creating segmentation GIF"):
71-
image = imageio.v2.imread(frame_path)
72-
writer.append_data(image)
73-
# Clean up temporary file
74-
Path(frame_path).unlink()
76+
for frame in tqdm(frames, desc="Creating segmentation GIF"):
77+
writer.append_data(frame)
7578

7679

7780
def plot_volume_changes(labels: np.ndarray, t_step: int, filepath: Path) -> None:

0 commit comments

Comments
 (0)