Skip to content

Commit f8d5e69

Browse files
committed
.
1 parent 0f50210 commit f8d5e69

File tree

1 file changed

+12
-32
lines changed

1 file changed

+12
-32
lines changed

examples/basic_cuda_example.py

Lines changed: 12 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -130,7 +130,7 @@
130130
#
131131
# Let's look at the frames decoded by CUDA decoder and compare them
132132
# against equivalent results from the CPU decoders.
133-
from typing import List, Optional
133+
import matplotlib.pyplot as plt
134134

135135

136136
def get_frames(timestamps: list[float], device: str):
@@ -149,45 +149,25 @@ def get_numpy_images(frames):
149149

150150
timestamps = [12, 19, 45, 131, 180]
151151
cpu_frames = get_frames(timestamps, device="cpu")
152-
cuda_frames = get_frames(timestamps, device="cuda")
153-
cpu_tensors = [frame.data for frame in cpu_frames]
154-
cuda_tensors = [frame.data for frame in cuda_frames]
152+
cuda_frames = get_frames(timestamps, device="cuda:0")
155153
cpu_numpy_images = get_numpy_images(cpu_frames)
156154
cuda_numpy_images = get_numpy_images(cuda_frames)
157155

158156

159-
def plot(
160-
frames1: List[torch.Tensor],
161-
frames2: List[torch.Tensor],
162-
title1: Optional[str] = None,
163-
title2: Optional[str] = None,
164-
):
165-
try:
166-
import matplotlib.pyplot as plt
167-
from torchvision.transforms.v2.functional import to_pil_image
168-
from torchvision.utils import make_grid
169-
except ImportError:
170-
print("Cannot plot, please run `pip install torchvision matplotlib`")
171-
return
172-
173-
plt.rcParams["savefig.bbox"] = "tight"
174-
175-
fig, ax = plt.subplots(1, 2)
176-
177-
ax[0].imshow(to_pil_image(make_grid(frames1, nrow=1)))
178-
ax[0].set(xticklabels=[], yticklabels=[], xticks=[], yticks=[])
179-
if title1 is not None:
180-
ax[0].set_title(title1)
181-
182-
ax[1].imshow(to_pil_image(make_grid(frames2, nrow=1)))
183-
ax[1].set(xticklabels=[], yticklabels=[], xticks=[], yticks=[])
184-
if title2 is not None:
185-
ax[1].set_title(title2)
157+
def plot_cpu_and_cuda_images():
158+
n_rows = len(timestamps)
159+
fig, axes = plt.subplots(n_rows, 2, figsize=[12.8, 16.0])
160+
for i in range(n_rows):
161+
axes[i][0].imshow(cpu_numpy_images[i])
162+
axes[i][1].imshow(cuda_numpy_images[i])
186163

164+
axes[0][0].set_title("CPU decoder")
165+
axes[0][1].set_title("CUDA decoder")
166+
plt.setp(axes, xticks=[], yticks=[])
187167
plt.tight_layout()
188168

189169

190-
plot(cpu_tensors, cuda_tensors, "CPU decoder", "CUDA decoder")
170+
plot_cpu_and_cuda_images()
191171

192172
# %%
193173
#

0 commit comments

Comments
 (0)