Skip to content

Commit 98b3a13

Browse files
AnnaStuckertn-poulsen
authored and committed
add code to save numbers in csv and h5 as numbers, not tensor(number)
1 parent 2fb39fc commit 98b3a13

File tree

1 file changed

+332
-48
lines changed

1 file changed

+332
-48
lines changed

dlclive/benchmark_pytorch.py

Lines changed: 332 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -292,54 +292,338 @@ def analyze_video(
292292

293293

294294
def save_poses_to_files(video_path, save_dir, bodyparts, poses):
295-
"""
296-
Save the keypoint poses detected in the video to CSV and HDF5 files.
295+
import csv
296+
import os
297+
import platform
298+
import subprocess
299+
import sys
300+
import time
301+
302+
import colorcet as cc
303+
import cv2
304+
import h5py
305+
import numpy as np
306+
import torch
307+
from PIL import ImageColor
308+
from pip._internal.operations import freeze
309+
310+
from dlclive import VERSION, DLCLive
311+
312+
def get_system_info() -> dict:
    """Return summary info for the system running the benchmark.

    Returns
    -------
    dict
        Dictionary containing the following system information:
        * ``host_name`` (str): name of machine, with spaces removed
        * ``op_sys`` (str): operating system, as reported by ``platform.platform()``
        * ``python`` (str): one path component of ``sys.executable``
          (presumably the conda/virtual environment directory name)
        * ``device_type`` (str): ``'GPU'`` if CUDA is available, else ``'CPU'``
        * ``device`` (list): single-element list holding the device name
        * ``freeze`` (list): list of installed packages and versions
        * ``python_version`` (str): python version
        * ``git_hash`` (str, None): If installed from git repository, hash of HEAD commit
        * ``dlclive_version`` (str): dlclive version from :data:`dlclive.VERSION`
    """

    # Get OS and host name
    op_sys = platform.platform()
    host_name = platform.node().replace(" ", "")

    # Get Python executable path: the interpreter sits one directory deeper
    # on non-Windows layouts (<env>/bin/python) than on Windows (<env>\python.exe)
    exe_parts = sys.executable.split(os.path.sep)
    if platform.system() == "Windows":
        host_python = exe_parts[-2]
    else:
        host_python = exe_parts[-3]

    # Try to get git hash if possible
    git_hash = None
    dlc_basedir = os.path.dirname(os.path.dirname(__file__))
    try:
        git_hash = (
            subprocess.check_output(["git", "rev-parse", "HEAD"], cwd=dlc_basedir)
            .decode("utf-8")
            .strip()
        )
    except (subprocess.CalledProcessError, OSError):
        # CalledProcessError: not installed from git repo, e.g., pypi.
        # OSError (incl. FileNotFoundError): git executable is missing or the
        # directory cannot be used; the hash is optional, so fall back to None.
        pass

    # Get device info (GPU or CPU)
    if torch.cuda.is_available():
        dev_type = "GPU"
        dev = [torch.cuda.get_device_name(torch.cuda.current_device())]
    else:
        # Imported lazily: only needed when no CUDA device is present
        from cpuinfo import get_cpu_info

        dev_type = "CPU"
        dev = [get_cpu_info()["brand_raw"]]

    return {
        "host_name": host_name,
        "op_sys": op_sys,
        "python": host_python,
        "device_type": dev_type,
        "device": dev,
        "freeze": list(freeze.freeze()),
        "python_version": sys.version,
        "git_hash": git_hash,
        "dlclive_version": VERSION,
    }
373+
374+
def analyze_video(
    video_path: str,
    model_path: str,
    model_type: str,
    device: str,
    precision: str = "FP32",
    snapshot: str = None,
    display=True,
    pcutoff=0.5,
    display_radius=5,
    resize=None,
    cropping=None,
    dynamic=(False, 0.5, 10),
    save_poses=False,
    save_dir="model_predictions",
    draw_keypoint_names=False,
    cmap="bmy",
    get_sys_info=True,
    save_video=False,
):
    """
    Analyze a video to track keypoints using an imported DeepLabCut model, visualize keypoints on the video, and optionally save the keypoint data and the labelled video.

    Parameters:
    -----------
    video_path : str
        The path to the video file to be analyzed.
    model_path : str
        Passed to ``DLCLive`` as ``path``.
    model_type : str
        Passed through to ``DLCLive``.
    device : str
        Passed through to ``DLCLive``.
    precision : str, optional, default="FP32"
        Passed through to ``DLCLive``.
    snapshot : str, optional, default=None
        Passed through to ``DLCLive``.
    display : bool, optional, default=True
        Passed through to ``DLCLive``.
    pcutoff : float, optional, default=0.5
        The probability cutoff value below which keypoints are not visualized.
    display_radius : int, optional, default=5
        The radius of the circles drawn to represent keypoints on the video frames.
    resize : tuple of int (width, height) or None, optional, default=None
        The size to which the frames should be resized. If None, the frames are not resized.
    cropping : list of int, optional, default=None
        Cropping parameters in pixel number: [x1, x2, y1, y2]
    dynamic : tuple, optional, default=(False, 0.5, 10)
        Passed through to ``DLCLive``.
    save_poses : bool, optional, default=False
        Whether to save the detected poses to CSV and HDF5 files.
    save_dir : str, optional, default="model_predictions"
        The directory where the output video and pose data will be saved.
        It is created if it does not exist.
    draw_keypoint_names : bool, optional, default=False
        Whether to draw the names of the keypoints on the video frames.
    cmap : str, optional, default="bmy"
        The colormap from the colorcet library to use for keypoint visualization.
    get_sys_info : bool, optional, default=True
        Whether to print the dictionary returned by ``get_system_info()``.
    save_video : bool, optional, default=False
        Whether to write a labelled copy of the video into ``save_dir``.

    Returns:
    --------
    (poses, times) : tuple of (list of dict, list), or None
        ``poses`` is a list of dictionaries where each dictionary contains the frame number and the corresponding pose data.
        ``times`` holds the second value returned by the model for each
        analyzed frame. ``None`` is returned if the video cannot be opened.
    """
    # Create the DLCLive object with cropping
    dlc_live = DLCLive(
        path=model_path,
        model_type=model_type,
        device=device,
        display=display,
        resize=resize,
        cropping=cropping,
        dynamic=dynamic,
        precision=precision,
        snapshot=snapshot,
    )

    # Ensure save directory exists
    os.makedirs(name=save_dir, exist_ok=True)

    # Load video
    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        print(f"Error: Could not open video file {video_path}")
        return

    # Per-frame results
    poses, times = [], []
    # Index of the frame currently being processed; advanced once per frame
    # read, including frames whose inference failed, so that stored frame
    # numbers always match positions in the video.
    frame_index = 0
    # init_inference must succeed once before get_pose can be used
    model_initialized = False
    vwriter = None

    try:
        # Retrieve bodypart names and number of keypoints
        bodyparts = dlc_live.cfg["metadata"]["bodyparts"]
        num_keypoints = len(bodyparts)

        if save_video:
            # Set colors and convert to RGB. The step is clamped to >= 1 so
            # that having more keypoints than colormap entries (or none) does
            # not produce an invalid slice step.
            cmap_colors = getattr(cc, cmap)
            color_step = max(1, len(cmap_colors) // max(1, num_keypoints))
            colors = [ImageColor.getrgb(color) for color in cmap_colors[::color_step]]

            # Define output video path
            video_name = os.path.splitext(os.path.basename(video_path))[0]
            output_video_path = os.path.join(
                save_dir, f"{video_name}_DLCLIVE_LABELLED.mp4"
            )

            # Get video writer setup
            fourcc = cv2.VideoWriter_fourcc(*"mp4v")
            fps = cap.get(cv2.CAP_PROP_FPS)
            frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
            frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))

            vwriter = cv2.VideoWriter(
                filename=output_video_path,
                fourcc=fourcc,
                fps=fps,
                frameSize=(frame_width, frame_height),
            )

        while True:
            ret, frame = cap.read()
            if not ret:
                break

            try:
                # TODO: dynamic cropping may jump back and forth between the
                # dynamically cropped and the original image.
                if not model_initialized:
                    pose, inf_time = dlc_live.init_inference(frame)  # load DLC model
                    model_initialized = True
                else:
                    pose, inf_time = dlc_live.get_pose(frame)
            except Exception as e:
                print(f"Error analyzing frame {frame_index}: {e}")
                # Still count this frame so later frame numbers stay correct
                frame_index += 1
                continue

            poses.append({"frame": frame_index, "pose": pose})
            times.append(inf_time)

            if save_video:
                # Visualize keypoints
                this_pose = pose["poses"][0][0]
                for j in range(this_pose.shape[0]):
                    if this_pose[j, 2] > pcutoff:
                        x, y = map(int, this_pose[j, :2])
                        cv2.circle(
                            frame,
                            center=(x, y),
                            radius=display_radius,
                            color=colors[j],
                            thickness=-1,
                        )

                        if draw_keypoint_names:
                            cv2.putText(
                                frame,
                                text=bodyparts[j],
                                org=(x + 10, y),
                                fontFace=cv2.FONT_HERSHEY_SIMPLEX,
                                fontScale=0.5,
                                color=colors[j],
                                thickness=1,
                                lineType=cv2.LINE_AA,
                            )

                vwriter.write(image=frame)
            frame_index += 1
    finally:
        # Always release the capture/writer handles, even if an error escapes
        cap.release()
        if vwriter is not None:
            vwriter.release()

    if get_sys_info:
        print(get_system_info())

    if save_poses:
        save_poses_to_files(video_path, save_dir, bodyparts, poses)

    return poses, times
546+
547+
def _pose_value_to_number(value):
    """Return ``value.item()`` for a ``torch.Tensor``; return anything else unchanged."""
    return value.item() if isinstance(value, torch.Tensor) else value


def save_poses_to_files(video_path, save_dir, bodyparts, poses):
    """
    Save the keypoint poses detected in the video to CSV and HDF5 files.

    Tensor entries are converted to plain numbers before writing, so the files
    contain e.g. ``0.5`` rather than ``tensor(0.5)``.

    Parameters:
    -----------
    video_path : str
        The path to the video file that was analyzed. Its base name (without
        extension) is used to name the output files.
    save_dir : str
        The directory where the pose data files will be saved.
    bodyparts : list of str
        A list of body part names corresponding to the keypoints.
    poses : list of dict
        A list of dictionaries where each dictionary contains the frame number and the corresponding pose data.
        The keypoints are read from ``entry["pose"]["poses"][0][0]`` and
        indexed as ``[keypoint, column]`` with columns x, y, confidence.

    Returns:
    --------
    None
    """
    # Column order of each keypoint row; shared by the CSV header and the
    # HDF5 dataset names so the two outputs cannot drift apart.
    axes = ("x", "y", "confidence")

    base_filename = os.path.splitext(os.path.basename(video_path))[0]
    csv_save_path = os.path.join(save_dir, f"{base_filename}_poses.csv")
    h5_save_path = os.path.join(save_dir, f"{base_filename}_poses.h5")

    # Save to CSV: one row per frame, keypoint values flattened in order
    with open(csv_save_path, mode="w", newline="") as file:
        writer = csv.writer(file)
        header = ["frame"] + [f"{bp}_{axis}" for bp in bodyparts for axis in axes]
        writer.writerow(header)
        for entry in poses:
            pose = entry["pose"]["poses"][0][0]
            row = [entry["frame"]]
            row.extend(_pose_value_to_number(item) for kp in pose for item in kp)
            writer.writerow(row)

    # Save to HDF5: one dataset for frame numbers plus one per bodypart/axis
    with h5py.File(h5_save_path, "w") as hf:
        hf.create_dataset(name="frames", data=[entry["frame"] for entry in poses])
        for i, bp in enumerate(bodyparts):
            for col, axis in enumerate(axes):
                hf.create_dataset(
                    name=f"{bp}_{axis}",
                    data=[
                        _pose_value_to_number(entry["pose"]["poses"][0][0][i, col])
                        for entry in poses
                    ],
                )

0 commit comments

Comments
 (0)