Skip to content

Commit e6a914a

Browse files
AnnaStuckertn-poulsen
authored and committed
add timestamp suffix to videos and csv/h5 files
1 parent 98b3a13 commit e6a914a

File tree

2 files changed

+350
-472
lines changed

2 files changed

+350
-472
lines changed

dlclive/LiveVideoInference.py

Lines changed: 87 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -17,20 +17,22 @@
1717

1818

1919
def get_system_info() -> dict:
20-
"""Return summary info for system running benchmark.
20+
"""
21+
Returns a summary of system information relevant to running benchmarks.
2122
2223
Returns
2324
-------
2425
dict
25-
Dictionary containing the following system information:
26-
* ``host_name`` (str): name of machine
27-
* ``op_sys`` (str): operating system
28-
* ``python`` (str): path to python (which conda/virtual environment)
29-
* ``device`` (tuple): (device type (``'GPU'`` or ``'CPU'```), device information)
30-
* ``freeze`` (list): list of installed packages and versions
31-
* ``python_version`` (str): python version
32-
* ``git_hash`` (str, None): If installed from git repository, hash of HEAD commit
33-
* ``dlclive_version`` (str): dlclive version from :data:`dlclive.VERSION`
26+
A dictionary containing the following system information:
27+
- host_name (str): Name of the machine.
28+
- op_sys (str): Operating system.
29+
- python (str): Path to the Python executable, indicating the conda/virtual environment in use.
30+
- device_type (str): Type of device used ('GPU' or 'CPU').
31+
- device (list): List containing the name of the GPU or CPU brand.
32+
- freeze (list): List of installed Python packages with their versions.
33+
- python_version (str): Version of Python in use.
34+
- git_hash (str or None): If installed from git repository, hash of HEAD commit.
35+
- dlclive_version (str): Version of the DLCLive package.
3436
"""
3537

3638
# Get OS and host name
@@ -101,35 +103,55 @@ def analyze_live_video(
101103
save_video=False,
102104
):
103105
"""
104-
Analyze a video to track keypoints using an imported DeepLabCut model, visualize keypoints on the video, and optionally save the keypoint data and the labelled video.
105-
106-
Parameters:
107-
-----------
106+
Analyzes a video to track keypoints using a DeepLabCut model, and optionally saves the keypoint data and the labeled video.
107+
108+
Parameters
109+
----------
110+
model_path : str
111+
Path to the DeepLabCut model.
112+
model_type : str
113+
Type of the model (e.g., 'onnx').
114+
device : str
115+
Device to run the model on ('cpu' or 'cuda').
108116
camera : float, default=0 (webcam)
109-
The camera to record the live video from
117+
The camera to record the live video from.
110118
experiment_name : str, default = "Test"
111119
Prefix to label generated pose and video files
120+
precision : str, optional, default='FP32'
121+
Precision type for the model ('FP32' or 'FP16').
122+
snapshot : str, optional
123+
Snapshot to use for the model, if using pytorch as model type.
124+
display : bool, optional, default=True
125+
Whether to display frame with labelled key points.
112126
pcutoff : float, optional, default=0.5
113-
The probability cutoff value below which keypoints are not visualized.
127+
Probability cutoff below which keypoints are not visualized.
114128
display_radius : int, optional, default=5
115-
The radius of the circles drawn to represent keypoints on the video frames.
116-
resize : tuple of int (width, height) or None, optional, default=None
117-
The size to which the frames should be resized. If None, the frames are not resized.
118-
cropping : list of int, optional, default=None
119-
Cropping parameters in pixel number: [x1, x2, y1, y2]
129+
Radius of circles drawn for keypoints on video frames.
130+
resize : tuple of int (width, height) or None, optional
131+
Resize dimensions for video frames. e.g. if resize = 0.5, the video will be processed in half the original size. If None, no resizing is applied.
132+
cropping : list of int or None, optional
133+
Cropping parameters [x1, x2, y1, y2] in pixels. If None, no cropping is applied.
134+
dynamic : tuple (state, p cutoff, margin), optional, default=(False, 0.5, 10)
135+
Parameters for dynamic cropping. If the state is true, then dynamic cropping will be performed. That means that if an object is detected (i.e. any body part > detection threshold), then object boundaries are computed according to the smallest/largest x position and smallest/largest y position of all body parts. This window is expanded by the margin and from then on only the posture within this crop is analyzed (until the object is lost, i.e. < detection threshold). The current position is utilized for updating the crop window for the next frame (this is why the margin is important and should be set large enough given the movement of the animal).
120136
save_poses : bool, optional, default=False
121137
Whether to save the detected poses to CSV and HDF5 files.
122-
save_dir : str, optional, default="model_predictions"
123-
The directory where the output video and pose data will be saved.
138+
save_dir : str, optional, default='model_predictions'
139+
Directory to save output data and labeled video.
124140
draw_keypoint_names : bool, optional, default=False
125-
Whether to draw the names of the keypoints on the video frames.
126-
cmap : str, optional, default="bmy"
127-
The colormap from the colorcet library to use for keypoint visualization.
141+
Whether to display keypoint names on video frames in the saved video.
142+
cmap : str, optional, default='bmy'
143+
Colormap from the colorcet library for keypoint visualization.
144+
get_sys_info : bool, optional, default=True
145+
Whether to print system information.
146+
save_video : bool, optional, default=False
147+
Whether to save the labeled video.
128148
129-
Returns:
130-
--------
131-
poses : list of dict
132-
A list of dictionaries where each dictionary contains the frame number and the corresponding pose data.
149+
Returns
150+
-------
151+
tuple
152+
A tuple containing:
153+
- poses (list of dict): List of pose data for each frame.
154+
- times (list of float): List of inference times for each frame.
133155
"""
134156
# Create the DLCLive object with cropping
135157
dlc_live = DLCLive(
@@ -147,6 +169,9 @@ def analyze_live_video(
147169
# Ensure save directory exists
148170
os.makedirs(name=save_dir, exist_ok=True)
149171

172+
# Get the current date and time as a string
173+
timestamp = time.strftime("%Y%m%d_%H%M%S")
174+
150175
# Load video
151176
cap = cv2.VideoCapture(camera)
152177
if not cap.isOpened():
@@ -171,7 +196,7 @@ def analyze_live_video(
171196

172197
# Define output video path
173198
output_video_path = os.path.join(
174-
save_dir, f"{experiment_name}_DLCLIVE_LABELLED.mp4"
199+
save_dir, f"{experiment_name}_DLCLIVE_LABELLED_{timestamp}.mp4"
175200
)
176201

177202
# Get video writer setup
@@ -188,29 +213,26 @@ def analyze_live_video(
188213
)
189214

190215
while True:
191-
start_time = time.time()
192216

193217
ret, frame = cap.read()
194218
if not ret:
195219
break
196220

197221
try:
198222
if frame_index == 0:
199-
pose = dlc_live.init_inference(frame) # load DLC model
223+
pose, inf_time = dlc_live.init_inference(frame) # load DLC model
200224
else:
201-
pose = dlc_live.get_pose(frame)
225+
pose, inf_time = dlc_live.get_pose(frame)
202226
except Exception as e:
203227
print(f"Error analyzing frame {frame_index}: {e}")
204228
continue
205229

206-
end_time = time.time()
207-
processing_time = end_time - start_time
208-
print(f"Frame {frame_index} processing time: {processing_time:.4f} seconds")
209-
210230
poses.append({"frame": frame_index, "pose": pose})
231+
times.append(inf_time)
232+
211233
if save_video:
212234
# Visualize keypoints
213-
this_pose = pose[0]["poses"][0][0]
235+
this_pose = pose["poses"][0][0]
214236
for j in range(this_pose.shape[0]):
215237
if this_pose[j, 2] > pcutoff:
216238
x, y = map(int, this_pose[j, :2])
@@ -255,33 +277,35 @@ def analyze_live_video(
255277
print(get_system_info())
256278

257279
if save_poses:
258-
save_poses_to_files(experiment_name, save_dir, bodyparts, poses)
280+
save_poses_to_files(
281+
experiment_name, save_dir, bodyparts, poses, timestamp=timestamp
282+
)
259283

260284
return poses, times
261285

262286

263-
def save_poses_to_files(experiment_name, save_dir, bodyparts, poses):
287+
def save_poses_to_files(experiment_name, save_dir, bodyparts, poses, timestamp):
264288
"""
265-
Save the keypoint poses detected in the video to CSV and HDF5 files.
289+
Saves the detected keypoint poses from the video to CSV and HDF5 files.
266290
267-
Parameters:
268-
-----------
269-
experiment_name : str
270-
Name of the experiment, used as a prefix for saving files.
291+
Parameters
292+
----------
293+
experiment_name : str
294+
Name of the experiment, used as a prefix for the saved files.
271295
save_dir : str
272-
The directory where the pose data files will be saved.
296+
Directory where the pose data files will be saved.
273297
bodyparts : list of str
274-
A list of body part names corresponding to the keypoints.
298+
List of body part names corresponding to the keypoints.
275299
poses : list of dict
276-
A list of dictionaries where each dictionary contains the frame number and the corresponding pose data.
300+
List of dictionaries containing frame numbers and corresponding pose data.
277301
278-
Returns:
279-
--------
302+
Returns
303+
-------
280304
None
281305
"""
282306
base_filename = os.path.splitext(os.path.basename(experiment_name))[0]
283-
csv_save_path = os.path.join(save_dir, f"{base_filename}_poses.csv")
284-
h5_save_path = os.path.join(save_dir, f"{base_filename}_poses.h5")
307+
csv_save_path = os.path.join(save_dir, f"{base_filename}_poses_{timestamp}.csv")
308+
h5_save_path = os.path.join(save_dir, f"{base_filename}_poses_{timestamp}.h5")
285309

286310
# Save to CSV
287311
with open(csv_save_path, mode="w", newline="") as file:
@@ -292,7 +316,7 @@ def save_poses_to_files(experiment_name, save_dir, bodyparts, poses):
292316
writer.writerow(header)
293317
for entry in poses:
294318
frame_num = entry["frame"]
295-
pose_data = entry["pose"][0]["poses"][0][0]
319+
pose_data = entry["pose"]["poses"][0][0]
296320
# Convert tensor data to numeric values
297321
row = [frame_num] + [
298322
item.item() if isinstance(item, torch.Tensor) else item
@@ -309,11 +333,9 @@ def save_poses_to_files(experiment_name, save_dir, bodyparts, poses):
309333
name=f"{bp}_x",
310334
data=[
311335
(
312-
entry["pose"][0]["poses"][0][0][i, 0].item()
313-
if isinstance(
314-
entry["pose"][0]["poses"][0][0][i, 0], torch.Tensor
315-
)
316-
else entry["pose"][0]["poses"][0][0][i, 0]
336+
entry["pose"]["poses"][0][0][i, 0].item()
337+
if isinstance(entry["pose"]["poses"][0][0][i, 0], torch.Tensor)
338+
else entry["pose"]["poses"][0][0][i, 0]
317339
)
318340
for entry in poses
319341
],
@@ -322,11 +344,9 @@ def save_poses_to_files(experiment_name, save_dir, bodyparts, poses):
322344
name=f"{bp}_y",
323345
data=[
324346
(
325-
entry["pose"][0]["poses"][0][0][i, 1].item()
326-
if isinstance(
327-
entry["pose"][0]["poses"][0][0][i, 1], torch.Tensor
328-
)
329-
else entry["pose"][0]["poses"][0][0][i, 1]
347+
entry["pose"]["poses"][0][0][i, 1].item()
348+
if isinstance(entry["pose"]["poses"][0][0][i, 1], torch.Tensor)
349+
else entry["pose"]["poses"][0][0][i, 1]
330350
)
331351
for entry in poses
332352
],
@@ -335,11 +355,9 @@ def save_poses_to_files(experiment_name, save_dir, bodyparts, poses):
335355
name=f"{bp}_confidence",
336356
data=[
337357
(
338-
entry["pose"][0]["poses"][0][0][i, 2].item()
339-
if isinstance(
340-
entry["pose"][0]["poses"][0][0][i, 2], torch.Tensor
341-
)
342-
else entry["pose"][0]["poses"][0][0][i, 2]
358+
entry["pose"]["poses"][0][0][i, 2].item()
359+
if isinstance(entry["pose"]["poses"][0][0][i, 2], torch.Tensor)
360+
else entry["pose"]["poses"][0][0][i, 2]
343361
)
344362
for entry in poses
345363
],

0 commit comments

Comments
 (0)