@@ -219,32 +219,17 @@ def benchmark(
219
219
times .append (inf_time )
220
220
221
221
if save_video :
222
- # Visualize keypoints
223
- this_pose = pose ["poses" ][0 ][0 ]
224
- for j in range (this_pose .shape [0 ]):
225
- if this_pose [j , 2 ] > pcutoff :
226
- x , y = map (int , this_pose [j , :2 ])
227
- cv2 .circle (
228
- frame ,
229
- center = (x , y ),
230
- radius = display_radius ,
231
- color = colors [j ],
232
- thickness = - 1 ,
233
- )
222
+ draw_pose_and_write (
223
+ frame = frame ,
224
+ pose = pose ,
225
+ colors = colors ,
226
+ bodyparts = bodyparts ,
227
+ pcutoff = pcutoff ,
228
+ display_radius = display_radius ,
229
+ draw_keypoint_names = draw_keypoint_names ,
230
+ vwriter = vwriter
231
+ )
234
232
235
- if draw_keypoint_names :
236
- cv2 .putText (
237
- frame ,
238
- text = bodyparts [j ],
239
- org = (x + 10 , y ),
240
- fontFace = cv2 .FONT_HERSHEY_SIMPLEX ,
241
- fontScale = 0.5 ,
242
- color = colors [j ],
243
- thickness = 1 ,
244
- lineType = cv2 .LINE_AA ,
245
- )
246
-
247
- vwriter .write (image = frame )
248
233
frame_index += 1
249
234
250
235
cap .release ()
@@ -291,6 +276,47 @@ def setup_video_writer(
291
276
292
277
return colors , vwriter
293
278
279
+ def draw_pose_and_write (
280
+ frame : np .ndarray ,
281
+ pose : np .ndarray ,
282
+ colors : list [tuple [int , int , int ]],
283
+ bodyparts : list [str ],
284
+ pcutoff : float ,
285
+ display_radius : int ,
286
+ draw_keypoint_names : bool ,
287
+ vwriter : cv2 .VideoWriter ,
288
+ ):
289
+ if len (pose .shape ) == 2 :
290
+ pose = pose [None ]
291
+
292
+ # Visualize keypoints
293
+ for i in range (pose .shape [0 ]):
294
+ for j in range (pose .shape [1 ]):
295
+ if pose [i , j , 2 ] > pcutoff :
296
+ x , y = map (int , pose [i , j , :2 ])
297
+ cv2 .circle (
298
+ frame ,
299
+ center = (x , y ),
300
+ radius = display_radius ,
301
+ color = colors [j ],
302
+ thickness = - 1 ,
303
+ )
304
+
305
+ if draw_keypoint_names :
306
+ cv2 .putText (
307
+ frame ,
308
+ text = bodyparts [j ],
309
+ org = (x + 10 , y ),
310
+ fontFace = cv2 .FONT_HERSHEY_SIMPLEX ,
311
+ fontScale = 0.5 ,
312
+ color = colors [j ],
313
+ thickness = 1 ,
314
+ lineType = cv2 .LINE_AA ,
315
+ )
316
+
317
+
318
+ vwriter .write (image = frame )
319
+
294
320
def save_poses_to_files (video_path , save_dir , bodyparts , poses , timestamp ):
295
321
"""
296
322
Saves the detected keypoint poses from the video to CSV and HDF5 files.
0 commit comments