@@ -243,7 +243,9 @@ def benchmark(
     print(get_system_info())

     if save_poses:
-        save_poses_to_files(video_path, save_dir, bodyparts, poses, timestamp=timestamp)
+        individuals = dlc_live.read_config()["metadata"].get("individuals", [])
+        n_individuals = len(individuals) or 1
+        save_poses_to_files(video_path, save_dir, n_individuals, bodyparts, poses, timestamp=timestamp)

     return poses, times

@@ -320,7 +322,7 @@ def draw_pose_and_write(

     vwriter.write(image=frame)

-def save_poses_to_files(video_path, save_dir, bodyparts, poses, timestamp):
+def save_poses_to_files(video_path, save_dir, n_individuals, bodyparts, poses, timestamp):
     """
     Saves the detected keypoint poses from the video to CSV and HDF5 files.
@@ -339,65 +341,48 @@ def save_poses_to_files(video_path, save_dir, bodyparts, poses, timestamp):
     -------
     None
     """
+    import pandas as pd

-    base_filename = os.path.splitext(os.path.basename(video_path))[0]
-    csv_save_path = os.path.join(save_dir, f"{base_filename}_poses_{timestamp}.csv")
-    h5_save_path = os.path.join(save_dir, f"{base_filename}_poses_{timestamp}.h5")
-
-    # Save to CSV
-    with open(csv_save_path, mode="w", newline="") as file:
-        writer = csv.writer(file)
-        header = ["frame"] + [
-            f"{bp}_{axis}" for bp in bodyparts for axis in ["x", "y", "confidence"]
-        ]
-        writer.writerow(header)
-        for entry in poses:
-            frame_num = entry["frame"]
-            pose = entry["pose"]["poses"][0][0]
-            row = [frame_num] + [
-                item.item() if isinstance(item, torch.Tensor) else item
-                for kp in pose
-                for item in kp
-            ]
-            writer.writerow(row)
-
-    # Save to HDF5
-    with h5py.File(h5_save_path, "w") as hf:
-        hf.create_dataset(name="frames", data=[entry["frame"] for entry in poses])
-        for i, bp in enumerate(bodyparts):
-            hf.create_dataset(
-                name=f"{bp}_x",
-                data=[
-                    (
-                        entry["pose"]["poses"][0][0][i, 0].item()
-                        if isinstance(entry["pose"]["poses"][0][0][i, 0], torch.Tensor)
-                        else entry["pose"]["poses"][0][0][i, 0]
-                    )
-                    for entry in poses
-                ],
-            )
-            hf.create_dataset(
-                name=f"{bp}_y",
-                data=[
-                    (
-                        entry["pose"]["poses"][0][0][i, 1].item()
-                        if isinstance(entry["pose"]["poses"][0][0][i, 1], torch.Tensor)
-                        else entry["pose"]["poses"][0][0][i, 1]
-                    )
-                    for entry in poses
-                ],
-            )
-            hf.create_dataset(
-                name=f"{bp}_confidence",
-                data=[
-                    (
-                        entry["pose"]["poses"][0][0][i, 2].item()
-                        if isinstance(entry["pose"]["poses"][0][0][i, 2], torch.Tensor)
-                        else entry["pose"]["poses"][0][0][i, 2]
-                    )
-                    for entry in poses
-                ],
-            )
+    base_filename = Path(video_path).stem
+    save_dir = Path(save_dir)
+    h5_save_path = save_dir / f"{base_filename}_poses_{timestamp}.h5"
+    csv_save_path = save_dir / f"{base_filename}_poses_{timestamp}.csv"
+
+    poses_array = _create_poses_np_array(n_individuals, bodyparts, poses)
+    flattened_poses = poses_array.reshape(poses_array.shape[0], -1)
+
+    if n_individuals == 1:
+        pdindex = pd.MultiIndex.from_product(
+            [bodyparts, ["x", "y", "likelihood"]], names=["bodyparts", "coords"]
+        )
+    else:
+        individuals = [f"individual_{i}" for i in range(n_individuals)]
+        pdindex = pd.MultiIndex.from_product(
+            [individuals, bodyparts, ["x", "y", "likelihood"]], names=["individuals", "bodyparts", "coords"]
+        )
+
+    pose_df = pd.DataFrame(flattened_poses, columns=pdindex)
+
+    pose_df.to_hdf(h5_save_path, key="df_with_missing", mode="w")
+    pose_df.to_csv(csv_save_path, index=False)
+
+
+def _create_poses_np_array(n_individuals: int, bodyparts: list, poses: list):
+    # Create numpy array with poses:
+    max_frame = max(p["frame"] for p in poses)
+    pose_target_shape = (n_individuals, len(bodyparts), 3)
+    poses_array = np.full((max_frame + 1, *pose_target_shape), np.nan)
+
+    for item in poses:
+        frame = item["frame"]
+        pose = item["pose"]
+        if pose.ndim == 2:
+            pose = pose[np.newaxis, :, :]
+        padded_pose = np.full(pose_target_shape, np.nan)
+        slices = tuple(slice(0, min(pose.shape[i], pose_target_shape[i])) for i in range(3))
+        padded_pose[slices] = pose[slices]
+        poses_array[frame] = padded_pose
+
+    return poses_array


 import argparse
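
As a quick illustration (not part of the diff), the sketch below mimics the DataFrame layout that the rewritten `save_poses_to_files` writes out. The bodypart names and pose values are invented, and it follows the single-animal branch (`n_individuals == 1`) of the code above, including the NaN padding for frames without a pose.

```python
# Hypothetical, self-contained sketch of the output layout produced by the new
# save_poses_to_files; bodyparts and pose values are invented for illustration.
import numpy as np
import pandas as pd

bodyparts = ["nose", "tail_base"]  # assumed example bodyparts
n_individuals = 1                  # single-animal branch of the diff

# Each entry mirrors what the benchmark loop collects: a frame index plus an
# (n_individuals, n_bodyparts, 3) array of x, y, likelihood.
poses = [
    {"frame": 0, "pose": np.random.rand(n_individuals, len(bodyparts), 3)},
    {"frame": 2, "pose": np.random.rand(n_individuals, len(bodyparts), 3)},  # frame 1 skipped
]

# Same padding idea as _create_poses_np_array: frames with no pose stay NaN.
max_frame = max(p["frame"] for p in poses)
poses_array = np.full((max_frame + 1, n_individuals, len(bodyparts), 3), np.nan)
for item in poses:
    poses_array[item["frame"]] = item["pose"]

pdindex = pd.MultiIndex.from_product(
    [bodyparts, ["x", "y", "likelihood"]], names=["bodyparts", "coords"]
)
pose_df = pd.DataFrame(poses_array.reshape(poses_array.shape[0], -1), columns=pdindex)
print(pose_df)  # row 1 is all NaN, matching the padding for skipped frames
```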