Skip to content

Commit 015f355

Browse files
committed
.
1 parent 3003be2 commit 015f355

File tree

1 file changed

+10
-18
lines changed

1 file changed

+10
-18
lines changed

examples/basic_cuda_example.py

Lines changed: 10 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -133,43 +133,35 @@
133133
# Let's look at the frames decoded by CUDA decoder and compare them
134134
# against equivalent results from the CPU decoders.
135135
import matplotlib.pyplot as plt
136+
from torchvision.transforms.v2.functional import to_pil_image
136137

137138

138139
def get_frames(timestamps: list[float], device: str):
139140
decoder = VideoDecoder(video_file, device=device)
140-
return [decoder.get_frame_played_at(ts) for ts in timestamps]
141-
142-
143-
def get_numpy_images(frames):
144-
numpy_images = []
145-
for frame in frames:
146-
# We transfer to the CPU so they can be visualized by matplotlib.
147-
numpy_image = frame.data.to("cpu").permute(1, 2, 0).numpy()
148-
numpy_images.append(numpy_image)
149-
return numpy_images
141+
return [decoder.get_frame_played_at(ts).data for ts in timestamps]
150142

151143

152144
timestamps = [12, 19, 45, 131, 180]
153145
cpu_frames = get_frames(timestamps, device="cpu")
154146
cuda_frames = get_frames(timestamps, device="cuda:0")
155-
cpu_numpy_images = get_numpy_images(cpu_frames)
156-
cuda_numpy_images = get_numpy_images(cuda_frames)
157147

158148

159-
def plot_cpu_and_cuda_images():
149+
def plot_cpu_and_cuda_frames(
150+
cpu_frames: list[torch.Tensor], cuda_frames: list[torch.Tensor]
151+
):
160152
n_rows = len(timestamps)
161153
fig, axes = plt.subplots(n_rows, 2, figsize=[12.8, 16.0])
162154
for i in range(n_rows):
163-
axes[i][0].imshow(cpu_numpy_images[i])
164-
axes[i][1].imshow(cuda_numpy_images[i])
155+
axes[i][0].imshow(to_pil_image(cpu_frames[i].to("cpu")))
156+
axes[i][1].imshow(to_pil_image(cuda_frames[i].to("cpu")))
165157

166-
axes[0][0].set_title("CPU decoder")
167-
axes[0][1].set_title("CUDA decoder")
158+
axes[0][0].set_title("CPU decoder", fontsize=24)
159+
axes[0][1].set_title("CUDA decoder", fontsize=24)
168160
plt.setp(axes, xticks=[], yticks=[])
169161
plt.tight_layout()
170162

171163

172-
plot_cpu_and_cuda_images()
164+
plot_cpu_and_cuda_frames(cpu_frames, cuda_frames)
173165

174166
# %%
175167
#

0 commit comments

Comments
 (0)