Skip to content

Commit 7227c98

Browse files
committed
Benchmark pytorch: extract setup_video_writer()
1 parent 5b10eb1 commit 7227c98

File tree

1 file changed

+41
-27
lines changed

1 file changed

+41
-27
lines changed

dlclive/benchmark_pytorch.py

Lines changed: 41 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import colorcet as cc
88
import cv2
99
import h5py
10+
import numpy as np
1011
from pathlib import Path
1112
from PIL import ImageColor
1213
from pip._internal.operations import freeze
@@ -182,34 +183,16 @@ def benchmark(
182183

183184
# Retrieve bodypart names and number of keypoints
184185
bodyparts = dlc_live.read_config()["metadata"]["bodyparts"]
185-
num_keypoints = len(bodyparts)
186186

187-
if save_video:
188-
# Set colors and convert to RGB
189-
cmap_colors = getattr(cc, cmap)
190-
colors = [
191-
ImageColor.getrgb(color)
192-
for color in cmap_colors[:: int(len(cmap_colors) / num_keypoints)]
193-
]
194-
195-
# Define output video path
196-
video_name = os.path.splitext(os.path.basename(video_path))[0]
197-
output_video_path = os.path.join(
198-
save_dir, f"{video_name}_DLCLIVE_LABELLED_{timestamp}.mp4"
199-
)
200-
201-
# Get video writer setup
202-
fourcc = cv2.VideoWriter_fourcc(*"mp4v")
203-
fps = cap.get(cv2.CAP_PROP_FPS)
204-
frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
205-
frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
206-
207-
vwriter = cv2.VideoWriter(
208-
filename=output_video_path,
209-
fourcc=fourcc,
210-
fps=fps,
211-
frameSize=(frame_width, frame_height),
212-
)
187+
colors, vwriter = setup_video_writer(
188+
video_path=video_path,
189+
save_dir=save_dir,
190+
timestamp=timestamp,
191+
num_keypoints=len(bodyparts),
192+
cmap=cmap,
193+
fps=cap.get(cv2.CAP_PROP_FPS),
194+
frame_size=(int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)), int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))),
195+
)
213196

214197
# Start empty dict to save poses to for each frame
215198
poses, times = [], []
@@ -276,6 +259,37 @@ def benchmark(
276259

277260
return poses, times
278261

262+
def setup_video_writer(
263+
video_path:str,
264+
save_dir:str,
265+
timestamp:str,
266+
num_keypoints:int,
267+
cmap:str,
268+
fps:float,
269+
frame_size:tuple[int, int],
270+
):
271+
# Set colors and convert to RGB
272+
cmap_colors = getattr(cc, cmap)
273+
colors = [
274+
ImageColor.getrgb(color)
275+
for color in cmap_colors[:: int(len(cmap_colors) / num_keypoints)]
276+
]
277+
278+
# Define output video path
279+
video_path = Path(video_path)
280+
video_name = video_path.stem # filename without extension
281+
output_video_path = Path(save_dir) / f"{video_name}_DLCLIVE_LABELLED_{timestamp}.mp4"
282+
283+
# Get video writer setup
284+
fourcc = cv2.VideoWriter_fourcc(*"mp4v")
285+
vwriter = cv2.VideoWriter(
286+
filename=output_video_path,
287+
fourcc=fourcc,
288+
fps=fps,
289+
frameSize=frame_size,
290+
)
291+
292+
return colors, vwriter
279293

280294
def save_poses_to_files(video_path, save_dir, bodyparts, poses, timestamp):
281295
"""

0 commit comments

Comments
 (0)