Skip to content

Commit 5a7a976

Browse files
committed
Benchmark pytorch: fix save_poses_to_files()
1 parent b493bef commit 5a7a976

File tree

1 file changed

+45
-60
lines changed

1 file changed

+45
-60
lines changed

dlclive/benchmark_pytorch.py

Lines changed: 45 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -243,7 +243,9 @@ def benchmark(
243243
print(get_system_info())
244244

245245
if save_poses:
246-
save_poses_to_files(video_path, save_dir, bodyparts, poses, timestamp=timestamp)
246+
individuals = dlc_live.read_config()["metadata"].get("individuals", [])
247+
n_individuals = len(individuals) or 1
248+
save_poses_to_files(video_path, save_dir, n_individuals, bodyparts, poses, timestamp=timestamp)
247249

248250
return poses, times
249251

@@ -320,7 +322,7 @@ def draw_pose_and_write(
320322

321323
vwriter.write(image=frame)
322324

323-
def save_poses_to_files(video_path, save_dir, bodyparts, poses, timestamp):
325+
def save_poses_to_files(video_path, save_dir, n_individuals, bodyparts, poses, timestamp):
324326
"""
325327
Saves the detected keypoint poses from the video to CSV and HDF5 files.
326328
@@ -339,65 +341,48 @@ def save_poses_to_files(video_path, save_dir, bodyparts, poses, timestamp):
339341
-------
340342
None
341343
"""
344+
import pandas as pd
342345

343-
base_filename = os.path.splitext(os.path.basename(video_path))[0]
344-
csv_save_path = os.path.join(save_dir, f"{base_filename}_poses_{timestamp}.csv")
345-
h5_save_path = os.path.join(save_dir, f"{base_filename}_poses_{timestamp}.h5")
346-
347-
# Save to CSV
348-
with open(csv_save_path, mode="w", newline="") as file:
349-
writer = csv.writer(file)
350-
header = ["frame"] + [
351-
f"{bp}_{axis}" for bp in bodyparts for axis in ["x", "y", "confidence"]
352-
]
353-
writer.writerow(header)
354-
for entry in poses:
355-
frame_num = entry["frame"]
356-
pose = entry["pose"]["poses"][0][0]
357-
row = [frame_num] + [
358-
item.item() if isinstance(item, torch.Tensor) else item
359-
for kp in pose
360-
for item in kp
361-
]
362-
writer.writerow(row)
363-
364-
# Save to HDF5
365-
with h5py.File(h5_save_path, "w") as hf:
366-
hf.create_dataset(name="frames", data=[entry["frame"] for entry in poses])
367-
for i, bp in enumerate(bodyparts):
368-
hf.create_dataset(
369-
name=f"{bp}_x",
370-
data=[
371-
(
372-
entry["pose"]["poses"][0][0][i, 0].item()
373-
if isinstance(entry["pose"]["poses"][0][0][i, 0], torch.Tensor)
374-
else entry["pose"]["poses"][0][0][i, 0]
375-
)
376-
for entry in poses
377-
],
378-
)
379-
hf.create_dataset(
380-
name=f"{bp}_y",
381-
data=[
382-
(
383-
entry["pose"]["poses"][0][0][i, 1].item()
384-
if isinstance(entry["pose"]["poses"][0][0][i, 1], torch.Tensor)
385-
else entry["pose"]["poses"][0][0][i, 1]
386-
)
387-
for entry in poses
388-
],
389-
)
390-
hf.create_dataset(
391-
name=f"{bp}_confidence",
392-
data=[
393-
(
394-
entry["pose"]["poses"][0][0][i, 2].item()
395-
if isinstance(entry["pose"]["poses"][0][0][i, 2], torch.Tensor)
396-
else entry["pose"]["poses"][0][0][i, 2]
397-
)
398-
for entry in poses
399-
],
400-
)
346+
base_filename = Path(video_path).stem
347+
save_dir = Path(save_dir)
348+
h5_save_path = save_dir / f"{base_filename}_poses_{timestamp}.h5"
349+
csv_save_path = save_dir / f"{base_filename}_poses_{timestamp}.csv"
350+
351+
poses_array = _create_poses_np_array(n_individuals, bodyparts, poses)
352+
flattened_poses = poses_array.reshape(poses_array.shape[0], -1)
353+
354+
if n_individuals == 1:
355+
pdindex = pd.MultiIndex.from_product(
356+
[bodyparts, ["x", "y", "likelihood"]], names=["bodyparts", "coords"]
357+
)
358+
else:
359+
individuals = [f"individual_{i}" for i in range(n_individuals)]
360+
pdindex = pd.MultiIndex.from_product(
361+
[individuals, bodyparts, ["x", "y", "likelihood"]], names=["individuals", "bodyparts", "coords"]
362+
)
363+
364+
pose_df = pd.DataFrame(flattened_poses, columns=pdindex)
365+
366+
pose_df.to_hdf(h5_save_path, key="df_with_missing", mode="w")
367+
pose_df.to_csv(csv_save_path, index=False)
368+
369+
def _create_poses_np_array(n_individuals: int, bodyparts: list, poses: list):
370+
# Create numpy array with poses:
371+
max_frame = max(p["frame"] for p in poses)
372+
pose_target_shape = (n_individuals, len(bodyparts), 3)
373+
poses_array = np.full((max_frame + 1, *pose_target_shape), np.nan)
374+
375+
for item in poses:
376+
frame = item["frame"]
377+
pose = item["pose"]
378+
if pose.ndim == 2:
379+
pose = pose[np.newaxis, :, :]
380+
padded_pose = np.full(pose_target_shape, np.nan)
381+
slices = tuple(slice(0, min(pose.shape[i], pose_target_shape[i])) for i in range(3))
382+
padded_pose[slices] = pose[slices]
383+
poses_array[frame] = padded_pose
384+
385+
return poses_array
401386

402387

403388
import argparse

0 commit comments

Comments
 (0)