@@ -292,54 +292,338 @@ def analyze_video(
 
 
 def save_poses_to_files(video_path, save_dir, bodyparts, poses):
-    """
-    Save the keypoint poses detected in the video to CSV and HDF5 files.
+import csv
+import os
+import platform
+import subprocess
+import sys
+import time
+
+import colorcet as cc
+import cv2
+import h5py
+import numpy as np
+import torch
+from PIL import ImageColor
+from pip._internal.operations import freeze
+
+from dlclive import VERSION, DLCLive
+
+def get_system_info() -> dict:
+    """Return summary info for system running benchmark.
+
+    Returns
+    -------
+    dict
+        Dictionary containing the following system information:
+        * ``host_name`` (str): name of machine
+        * ``op_sys`` (str): operating system
+        * ``python`` (str): path to python (which conda/virtual environment)
+        * ``device_type`` (str), ``device`` (list): ``'GPU'`` or ``'CPU'``, and the device name(s)
+        * ``freeze`` (list): list of installed packages and versions
+        * ``python_version`` (str): python version
+        * ``git_hash`` (str or None): hash of HEAD commit, if installed from a git repository
+        * ``dlclive_version`` (str): dlclive version from :data:`dlclive.VERSION`
+    """
+
+    # Get OS and host name
+    op_sys = platform.platform()
+    host_name = platform.node().replace(" ", "")
+
+    # Get Python executable path
+    if platform.system() == "Windows":
+        host_python = sys.executable.split(os.path.sep)[-2]
+    else:
+        host_python = sys.executable.split(os.path.sep)[-3]
+
+    # Try to get git hash if possible
+    git_hash = None
+    dlc_basedir = os.path.dirname(os.path.dirname(__file__))
+    try:
+        git_hash = (
+            subprocess.check_output(["git", "rev-parse", "HEAD"], cwd=dlc_basedir)
+            .decode("utf-8")
+            .strip()
+        )
+    except subprocess.CalledProcessError:
+        # Not installed from git repo, e.g., pypi
+        pass
+
+    # Get device info (GPU or CPU)
+    if torch.cuda.is_available():
+        dev_type = "GPU"
+        dev = [torch.cuda.get_device_name(torch.cuda.current_device())]
+    else:
+        from cpuinfo import get_cpu_info
+
+        dev_type = "CPU"
+        dev = [get_cpu_info()["brand_raw"]]
+
+    return {
+        "host_name": host_name,
+        "op_sys": op_sys,
+        "python": host_python,
+        "device_type": dev_type,
+        "device": dev,
+        "freeze": list(freeze.freeze()),
+        "python_version": sys.version,
+        "git_hash": git_hash,
+        "dlclive_version": VERSION,
+    }
+
+def analyze_video(
+    video_path: str,
+    model_path: str,
+    model_type: str,
+    device: str,
+    precision: str = "FP32",
+    snapshot: str = None,
+    display=True,
+    pcutoff=0.5,
+    display_radius=5,
+    resize=None,
+    cropping=None,  # cropping parameters in pixel number: [x1, x2, y1, y2]
+    dynamic=(False, 0.5, 10),
+    save_poses=False,
+    save_dir="model_predictions",
+    draw_keypoint_names=False,
+    cmap="bmy",
+    get_sys_info=True,
+    save_video=False,
+):
+    """
+    Analyze a video with an exported DeepLabCut model, optionally draw the detected keypoints on the frames, and optionally save the keypoint data and the labelled video.
+
+    Parameters:
+    -----------
+    video_path : str
+        The path to the video file to be analyzed.
+    model_path, model_type, device, precision, snapshot, display, dynamic
+        Model settings passed to the DLCLive object used for inference.
+    pcutoff : float, optional, default=0.5
+        The confidence cutoff below which keypoints are not visualized.
+    display_radius : int, optional, default=5
+        The radius of the circles drawn to represent keypoints on the video frames.
+    resize : tuple of int (width, height) or None, optional, default=None
+        The size to which the frames should be resized. If None, the frames are not resized.
+    cropping : list of int, optional, default=None
+        Cropping parameters in pixel number: [x1, x2, y1, y2].
+    save_poses : bool, optional, default=False
+        Whether to save the detected poses to CSV and HDF5 files.
+    save_dir : str, optional, default="model_predictions"
+        The directory where the output video and pose data will be saved.
+    draw_keypoint_names : bool, optional, default=False
+        Whether to draw the names of the keypoints on the video frames.
+    cmap : str, optional, default="bmy"
+        The colormap from the colorcet library to use for keypoint visualization.
+
+    Returns:
+    --------
+    poses, times : list of dict, list of float
+        Per-frame pose data (frame number plus pose array) and the corresponding inference times.
+    """
+    # Create the DLCLive object with cropping
+    dlc_live = DLCLive(
+        path=model_path,
+        model_type=model_type,
+        device=device,
+        display=display,
+        resize=resize,
+        cropping=cropping,  # Pass the cropping parameter
+        dynamic=dynamic,
+        precision=precision,
+        snapshot=snapshot,
+    )
 
-    Parameters:
-    -----------
-    video_path : str
-        The path to the video file that was analyzed.
-    save_dir : str
-        The directory where the pose data files will be saved.
-    bodyparts : list of str
-        A list of body part names corresponding to the keypoints.
-    poses : list of dict
-        A list of dictionaries where each dictionary contains the frame number and the corresponding pose data.
+    # Ensure save directory exists
+    os.makedirs(name=save_dir, exist_ok=True)
 
-    Returns:
-    --------
-    None
-    """
-    base_filename = os.path.splitext(os.path.basename(video_path))[0]
-    csv_save_path = os.path.join(save_dir, f"{base_filename}_poses.csv")
-    h5_save_path = os.path.join(save_dir, f"{base_filename}_poses.h5")
-
-    # Save to CSV
-    with open(csv_save_path, mode="w", newline="") as file:
-        writer = csv.writer(file)
-        header = ["frame"] + [
-            f"{bp}_{axis}" for bp in bodyparts for axis in ["x", "y", "confidence"]
-        ]
-        writer.writerow(header)
-        for entry in poses:
-            frame_num = entry["frame"]
-            pose = entry["pose"]["poses"][0][0]
-            row = [frame_num] + [item for kp in pose for item in kp]
-            writer.writerow(row)
-
-    # Save to HDF5
-    with h5py.File(h5_save_path, "w") as hf:
-        hf.create_dataset(name="frames", data=[entry["frame"] for entry in poses])
-        for i, bp in enumerate(bodyparts):
-            hf.create_dataset(
-                name=f"{bp}_x",
-                data=[entry["pose"]["poses"][0][0][i, 0].item() for entry in poses],
-            )
-            hf.create_dataset(
-                name=f"{bp}_y",
-                data=[entry["pose"]["poses"][0][0][i, 1].item() for entry in poses],
-            )
-            hf.create_dataset(
-                name=f"{bp}_confidence",
-                data=[entry["pose"]["poses"][0][0][i, 2].item() for entry in poses],
+    # Load video
+    cap = cv2.VideoCapture(video_path)
+    if not cap.isOpened():
+        print(f"Error: Could not open video file {video_path}")
+        return
+
+    # Lists to collect the pose and inference time for each analyzed frame
+    poses, times = [], []
+    # Index of the frame currently being processed; incremented at the end of each loop pass
+    frame_index = 0
+
+    # Retrieve bodypart names and number of keypoints
+    bodyparts = dlc_live.cfg["metadata"]["bodyparts"]
+    num_keypoints = len(bodyparts)
+
+    if save_video:
+        # Set colors and convert to RGB
+        cmap_colors = getattr(cc, cmap)
+        colors = [
+            ImageColor.getrgb(color)
+            for color in cmap_colors[:: int(len(cmap_colors) / num_keypoints)]
+        ]
+
+        # Define output video path
+        video_name = os.path.splitext(os.path.basename(video_path))[0]
+        output_video_path = os.path.join(
+            save_dir, f"{video_name}_DLCLIVE_LABELLED.mp4"
+        )
+
+        # Get video writer setup
+        fourcc = cv2.VideoWriter_fourcc(*"mp4v")
+        fps = cap.get(cv2.CAP_PROP_FPS)
+        frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
+        frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
+
+        vwriter = cv2.VideoWriter(
+            filename=output_video_path,
+            fourcc=fourcc,
+            fps=fps,
+            frameSize=(frame_width, frame_height),
         )
+
+    while True:
+        start_time = time.time()
+
+        ret, frame = cap.read()
+        if not ret:
+            break
+        # if frame_index == 0:
+        #     pose = dlc_live.init_inference(frame)  # load DLC model
+        try:
+            # pose = dlc_live.get_pose(frame)
+            if frame_index == 0:
+                # dlc_live.dynamic = (False, dynamic[1], dynamic[2])  # TODO: fix dynamic cropping jumping back and forth between the dynamically cropped and original image
+                pose, inf_time = dlc_live.init_inference(frame)  # load DLC model
+            else:
+                # dlc_live.dynamic = dynamic
+                pose, inf_time = dlc_live.get_pose(frame)
+        except Exception as e:
+            print(f"Error analyzing frame {frame_index}: {e}")
+            continue
+
+        poses.append({"frame": frame_index, "pose": pose})
+        times.append(inf_time)
+
+        if save_video:
+            # Visualize keypoints
+            this_pose = pose["poses"][0][0]
+            for j in range(this_pose.shape[0]):
+                if this_pose[j, 2] > pcutoff:
+                    x, y = map(int, this_pose[j, :2])
+                    cv2.circle(
+                        frame,
+                        center=(x, y),
+                        radius=display_radius,
+                        color=colors[j],
+                        thickness=-1,
+                    )
+
+                    if draw_keypoint_names:
+                        cv2.putText(
+                            frame,
+                            text=bodyparts[j],
+                            org=(x + 10, y),
+                            fontFace=cv2.FONT_HERSHEY_SIMPLEX,
+                            fontScale=0.5,
+                            color=colors[j],
+                            thickness=1,
+                            lineType=cv2.LINE_AA,
+                        )
+
+            vwriter.write(image=frame)
+        frame_index += 1
+
+    cap.release()
+    if save_video:
+        vwriter.release()
+
+    if get_sys_info:
+        print(get_system_info())
+
+    if save_poses:
+        save_poses_to_files(video_path, save_dir, bodyparts, poses)
+
+    return poses, times
+
+def save_poses_to_files(video_path, save_dir, bodyparts, poses):
+    """
+    Save the keypoint poses detected in the video to CSV and HDF5 files.
+
+    Parameters:
+    -----------
+    video_path : str
+        The path to the video file that was analyzed.
+    save_dir : str
+        The directory where the pose data files will be saved.
+    bodyparts : list of str
+        A list of body part names corresponding to the keypoints.
+    poses : list of dict
+        A list of dictionaries where each dictionary contains the frame number and the corresponding pose data.
+
+    Returns:
+    --------
+    None
+    """
+    base_filename = os.path.splitext(os.path.basename(video_path))[0]
+    csv_save_path = os.path.join(save_dir, f"{base_filename}_poses.csv")
+    h5_save_path = os.path.join(save_dir, f"{base_filename}_poses.h5")
+
+    # Save to CSV
+    with open(csv_save_path, mode="w", newline="") as file:
+        writer = csv.writer(file)
+        header = ["frame"] + [
+            f"{bp}_{axis}" for bp in bodyparts for axis in ["x", "y", "confidence"]
+        ]
+        writer.writerow(header)
+        for entry in poses:
+            frame_num = entry["frame"]
+            pose = entry["pose"]["poses"][0][0]
+            row = [frame_num] + [
+                item.item() if isinstance(item, torch.Tensor) else item
+                for kp in pose
+                for item in kp
+            ]
+            writer.writerow(row)
+
+    # Save to HDF5
+    with h5py.File(h5_save_path, "w") as hf:
+        hf.create_dataset(name="frames", data=[entry["frame"] for entry in poses])
+        for i, bp in enumerate(bodyparts):
+            hf.create_dataset(
+                name=f"{bp}_x",
+                data=[
+                    (
+                        entry["pose"]["poses"][0][0][i, 0].item()
+                        if isinstance(
+                            entry["pose"]["poses"][0][0][i, 0], torch.Tensor
+                        )
+                        else entry["pose"]["poses"][0][0][i, 0]
+                    )
+                    for entry in poses
+                ],
+            )
+            hf.create_dataset(
+                name=f"{bp}_y",
+                data=[
+                    (
+                        entry["pose"]["poses"][0][0][i, 1].item()
+                        if isinstance(
+                            entry["pose"]["poses"][0][0][i, 1], torch.Tensor
+                        )
+                        else entry["pose"]["poses"][0][0][i, 1]
+                    )
+                    for entry in poses
+                ],
+            )
+            hf.create_dataset(
+                name=f"{bp}_confidence",
+                data=[
+                    (
+                        entry["pose"]["poses"][0][0][i, 2].item()
+                        if isinstance(
+                            entry["pose"]["poses"][0][0][i, 2], torch.Tensor
+                        )
+                        else entry["pose"]["poses"][0][0][i, 2]
+                    )
+                    for entry in poses
+                ],
+            )
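
For reference, a minimal usage sketch of the new analyze_video entry point, assuming the functions in the diff above are in scope; the video and model paths are placeholders, and the model_type/device strings are assumptions about what the installed DLCLive build accepts:

# Hypothetical example -- paths and backend strings below are placeholders, not values from this repository.
poses, times = analyze_video(
    video_path="videos/mouse_test.mp4",      # placeholder video file
    model_path="exported-models/my-model",   # placeholder exported model directory
    model_type="pytorch",                    # assumed identifier for the PyTorch backend
    device="cuda",                           # or "cpu" if no GPU is available
    pcutoff=0.5,
    save_poses=True,    # writes <video>_poses.csv and <video>_poses.h5 via save_poses_to_files
    save_video=True,    # writes <video>_DLCLIVE_LABELLED.mp4 with keypoints drawn
    save_dir="model_predictions",
)
print(f"Analyzed {len(poses)} frames, mean inference time {sum(times) / len(times):.4f} s")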