Skip to content

Commit c2605de

Browse files
committed
Revamp clips module
1 parent 4ca0830 commit c2605de

File tree

2 files changed

+119
-179
lines changed

2 files changed

+119
-179
lines changed

poseinterface/clips.py

Lines changed: 119 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,119 @@
1+
"""Functions to extract clips from poseinterface videos."""
2+
3+
import argparse
4+
import json
5+
import logging
6+
import sys
7+
from pathlib import Path
8+
9+
import sleap_io as sio
10+
11+
12+
def extract_clip(
13+
video_path: str | Path,
14+
start_frame: int,
15+
duration: int,
16+
):
17+
"""Extract clip and clip labels.
18+
19+
We assume:
20+
- the input video filename is in the format
21+
`sub-<subjectID>_ses-<sessionID>_cam-<camID>.mp4`,
22+
- a `sub-<subjectID>_ses-<sessionID>_cam-<camID>_cliplabels.json`
23+
file with tracks for the full video exists alongside the input video,
24+
where the `id` in `images` corresponds to the global video frame 0-based
25+
indices (note that the local frame index and the global frame index is the
26+
same if the data refers to the whole video),
27+
- `start_frame` is 0-based index,
28+
- `duration` is len(clip).
29+
"""
30+
# Read video as array
31+
video_path = Path(video_path)
32+
video = sio.load_video(video_path)
33+
logging.info(
34+
f"filename: {video_path.name}, fps: {video.fps}, shape: {video.shape}"
35+
)
36+
37+
# Slice clip and save as mp4
38+
clip = video[start_frame : start_frame + duration]
39+
clip_path = f"{video.filename}_start-{start_frame}_dur-{duration}.mp4"
40+
sio.save_video(clip, clip_path, fps=video.fps)
41+
42+
# Generate cliplabels.json from the full video labels
43+
clip_json = _extract_cliplabels(video_path, start_frame, duration)
44+
45+
return clip_path, clip_json
46+
47+
48+
def _extract_cliplabels(video_path, start_frame, duration):
49+
"""Extract clip labels from the video cliplabels.json file."""
50+
# Read file with labels for the whole video
51+
video_json = video_path.parent / f"{video_path.stem}_cliplabels.json"
52+
with open(video_json) as f:
53+
video_labels = json.load(f)
54+
55+
# Keep only data from the images in the clip
56+
clip_labels = {}
57+
clip_labels["images"] = [
58+
img
59+
for img in video_labels["images"]
60+
if (img["id"] >= start_frame | img["id"] < start_frame + duration)
61+
]
62+
clip_labels["annotations"] = [
63+
annot
64+
for annot in video_labels["annotations"]
65+
if (
66+
annot["image_id"]
67+
>= start_frame | annot["image_id"]
68+
< start_frame + duration
69+
)
70+
]
71+
clip_labels["categories"] = video_labels["categories"]
72+
73+
# Save json with filtered data
74+
clip_json = (
75+
video_path.parent / f"{video_path.stem}_"
76+
f"start-{start_frame}_dur-{duration}_cliplabels.json"
77+
)
78+
with open(clip_json) as f:
79+
json.dump(clip_labels, f)
80+
81+
return clip_json
82+
83+
84+
def main(args: argparse.Namespace):
85+
# Extract clip
86+
extract_clip(args.video_path, args.start_frame, args.duration)
87+
88+
89+
def parse_args(args) -> argparse.Namespace:
90+
"""Parse command-line arguments."""
91+
parser = argparse.ArgumentParser(description="Extract clips from video")
92+
parser.add_argument(
93+
"--video_path",
94+
type=str,
95+
required=True,
96+
help="Path to video file to clip.",
97+
)
98+
parser.add_argument(
99+
"--start_frame",
100+
type=int,
101+
require=True,
102+
help="Start frame of the clip as a 0-based index.",
103+
)
104+
parser.add_argument(
105+
"--duration",
106+
type=int,
107+
required=True,
108+
help="Total length of the output clip in frames",
109+
)
110+
return parser.parse_args(args)
111+
112+
113+
def wrapper():
114+
args = parse_args(sys.argv[1:])
115+
main(args)
116+
117+
118+
if __name__ == "__main__":
119+
wrapper()

poseinterface/video.py

Lines changed: 0 additions & 179 deletions
This file was deleted.

0 commit comments

Comments
 (0)