Skip to content

Commit ad62af4

Browse files
committed
Benchmark pytorch: extract draw_pose_and_write()
1 parent 7227c98 commit ad62af4

File tree

1 file changed

+51
-25
lines changed

1 file changed

+51
-25
lines changed

dlclive/benchmark_pytorch.py

Lines changed: 51 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -219,32 +219,17 @@ def benchmark(
219219
times.append(inf_time)
220220

221221
if save_video:
222-
# Visualize keypoints
223-
this_pose = pose["poses"][0][0]
224-
for j in range(this_pose.shape[0]):
225-
if this_pose[j, 2] > pcutoff:
226-
x, y = map(int, this_pose[j, :2])
227-
cv2.circle(
228-
frame,
229-
center=(x, y),
230-
radius=display_radius,
231-
color=colors[j],
232-
thickness=-1,
233-
)
222+
draw_pose_and_write(
223+
frame=frame,
224+
pose=pose,
225+
colors=colors,
226+
bodyparts=bodyparts,
227+
pcutoff=pcutoff,
228+
display_radius=display_radius,
229+
draw_keypoint_names=draw_keypoint_names,
230+
vwriter=vwriter
231+
)
234232

235-
if draw_keypoint_names:
236-
cv2.putText(
237-
frame,
238-
text=bodyparts[j],
239-
org=(x + 10, y),
240-
fontFace=cv2.FONT_HERSHEY_SIMPLEX,
241-
fontScale=0.5,
242-
color=colors[j],
243-
thickness=1,
244-
lineType=cv2.LINE_AA,
245-
)
246-
247-
vwriter.write(image=frame)
248233
frame_index += 1
249234

250235
cap.release()
@@ -291,6 +276,47 @@ def setup_video_writer(
291276

292277
return colors, vwriter
293278

279+
def draw_pose_and_write(
280+
frame: np.ndarray,
281+
pose: np.ndarray,
282+
colors: list[tuple[int, int, int]],
283+
bodyparts: list[str],
284+
pcutoff: float,
285+
display_radius: int,
286+
draw_keypoint_names: bool,
287+
vwriter: cv2.VideoWriter,
288+
):
289+
if len(pose.shape) == 2:
290+
pose = pose[None]
291+
292+
# Visualize keypoints
293+
for i in range(pose.shape[0]):
294+
for j in range(pose.shape[1]):
295+
if pose[i, j, 2] > pcutoff:
296+
x, y = map(int, pose[i, j, :2])
297+
cv2.circle(
298+
frame,
299+
center=(x, y),
300+
radius=display_radius,
301+
color=colors[j],
302+
thickness=-1,
303+
)
304+
305+
if draw_keypoint_names:
306+
cv2.putText(
307+
frame,
308+
text=bodyparts[j],
309+
org=(x + 10, y),
310+
fontFace=cv2.FONT_HERSHEY_SIMPLEX,
311+
fontScale=0.5,
312+
color=colors[j],
313+
thickness=1,
314+
lineType=cv2.LINE_AA,
315+
)
316+
317+
318+
vwriter.write(image=frame)
319+
294320
def save_poses_to_files(video_path, save_dir, bodyparts, poses, timestamp):
295321
"""
296322
Saves the detected keypoint poses from the video to CSV and HDF5 files.

0 commit comments

Comments
 (0)