Commit ce28eeb

feat: overlay command

1 parent f4f48d7

11 files changed, +325 -76 lines

CHANGELOG.md (18 additions, 0 deletions)

@@ -0,0 +1,18 @@
+# Changelog
+
+All notable changes to this project will be documented in this file.
+
+The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/),
+and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
+
+## [1.1.0] - 2025-10-14
+
+### Added
+
+- New `overlay` command to render a video of the accumulated tracked poses.
+
+## [1.0.0]
+
+### Added
+
+- Initial release

Makefile (11 additions, 0 deletions)

@@ -0,0 +1,11 @@
+smoke-test:
+	echo "smoke testing track feature"
+	uv run choreopath track examples/fall-recovery-4.mp4 test-data/test.csv
+	echo "smoke testing draw feature"
+	uv run choreopath draw test-data/test.csv test-data/test.svg
+	echo "smoke testing analyze feature"
+	uv run choreopath analyze test-data/test.csv
+	echo "smoke testing overlay feature"
+	uv run choreopath overlay examples/fall-recovery-4.mp4 test-data/overlay.mp4
+	echo "smoke testing overlay --paths-only feature"
+	uv run choreopath overlay --paths-only examples/fall-recovery-4.mp4 test-data/overlay-paths-only.mp4

README.md (18 additions, 2 deletions)

@@ -48,6 +48,12 @@ choreopath draw tracking_data.csv output.svg
 choreopath draw tracking_data.csv output.svg --width 1920 --height 1080 --min-visibility 0.7
 ```
 
+### Generate an overlay video
+
+```bash
+choreopath overlay video.mp4 video-overlay.mp4
+```
+
 ### Analyze tracking data
 
 ```bash
@@ -62,9 +68,19 @@ generated videos of dancers to create SVG that I could then plot using a pen plo
 See example files:
 
 [original video](examples/fall-recovery-4.mp4)
+[overlay video](examples/fall-recovery-4-overlay.mp4)
 [animated tracking data](examples/fall-recovery-4-animation.mp4)
-![examples/fall-recovery-4.svg](https://github.com/marcw/choreopath/blob/main/examples/fall-recovery-4.svg)
+[examples/fall-recovery-4.svg](https://github.com/marcw/choreopath/blob/main/examples/fall-recovery-4.svg)
+
+## Development
+
+- Use `uv`
+- Run the smoke tests with `make smoke-test`
+
+## Changelog
+
+See [CHANGELOG.md](CHANGELOG.md).
 
 ## License
 
-This software is under a MIT license. Please see [LICENSE.md](LICENSE.md)
+This software is under an MIT license. See [LICENSE.md](LICENSE.md).
Binary file changed (934 KB), not shown.

pyproject.toml (1 addition, 1 deletion)

@@ -1,6 +1,6 @@
 [project]
 name = "choreopath"
-version = "1.0.1"
+version = "1.1.0"
 description = "Transform human movement into generative art. Track body poses from video and create SVG visualizations of motion trajectories"
 readme = "README.md"
 requires-python = ">=3.12,<3.13"

src/choreopath/cli.py (34 additions, 2 deletions)

@@ -4,6 +4,7 @@
 from .video import Video
 from .svg_generator import SVGGenerator
 from .tracking_data import TrackingData
+from .video_overlay_renderer import VideoOverlayRenderer
 
 @click.group()
 def cli():
@@ -15,13 +16,13 @@ def cli():
 @click.option("--min-detection-confidence", type=float, default=0.5)
 @click.option("--min-tracking-confidence", type=float, default=0.5)
 def track(src, dst, min_detection_confidence, min_tracking_confidence):
-    video = Video(src)
+    video = Video(src, min_detection_confidence, min_tracking_confidence)
     click.echo(f"Tracking poses in {src}")
     click.echo("Found {} frames".format(video.total_frames()))
     click.echo("FPS: {}".format(video.fps()))
     click.echo("Tracking poses with min detection confidence: {} and min tracking confidence: {}".format(min_detection_confidence, min_tracking_confidence))
 
-    tracking_data = video.track_poses(min_detection_confidence, min_tracking_confidence)
+    tracking_data = video.track_poses()
 
     if tracking_data:
         df = pd.DataFrame(tracking_data)
@@ -30,6 +31,8 @@ def track(src, dst, min_detection_confidence, min_tracking_confidence):
         click.echo(f"Total data points: {len(tracking_data)}")
     else:
         click.echo("No tracking data found. Check if there are people visible in the video.")
+
+    video.close()
 
 @cli.command(help='Generate SVG trajectories from body tracking data')
 @click.argument('src', type=click.Path(exists=True, readable=True, dir_okay=False))
@@ -75,3 +78,32 @@ def analyze(src, animation, fps):
 
     click.echo(f"\nGenerating tracking data animation to: {animation}")
     tracking_data.to_animation(animation, fps)
+
+@cli.command(help='Generate video with progressive pose path overlays')
+@click.argument('video', type=click.Path(exists=True, readable=True, dir_okay=False))
+@click.argument('output', type=click.Path(writable=True, dir_okay=False))
+@click.option('--min-detection-confidence', type=float, default=0.5, help='Minimum detection confidence')
+@click.option('--min-tracking-confidence', type=float, default=0.5, help='Minimum tracking confidence')
+@click.option('--min-visibility', type=float, default=0.5, help='Minimum visibility threshold')
+@click.option('--line-thickness', type=int, default=2, help='Path line thickness in pixels')
+@click.option('--no-current-point', is_flag=True, help='Disable current position marker')
+@click.option('--paths-only', is_flag=True, help='Render only paths')
+def overlay(video, output, min_detection_confidence, min_tracking_confidence, min_visibility, line_thickness, no_current_point, paths_only):
+    click.echo(f"Generating video overlay from {video} to {output}")
+    if paths_only:
+        click.echo("Mode: Paths only (black background)")
+
+    video = Video(video, min_detection_confidence, min_tracking_confidence)
+
+    renderer = VideoOverlayRenderer(
+        line_thickness=line_thickness,
+        show_current_point=not no_current_point,
+        min_visibility=min_visibility,
+        paths_only=paths_only,
+    )
+
+    renderer.render_overlay(video=video, output_path=output)
+
+    video.close()
+
+    click.echo(f"Video overlay saved to: {output}")

src/choreopath/colors.py (45 additions, 0 deletions)

@@ -0,0 +1,45 @@
+from typing import Tuple
+
+DEFAULT_PALETTE = {
+    'face': '#f87171',
+    'left_arm': '#fb923c',
+    'right_arm': '#facc15',
+    'hips': '#71717a',
+    'left_leg': '#06b6d4',
+    'right_leg': '#3b82f6'
+}
+
+class Palette:
+    def __init__(self):
+        self.body_colors = DEFAULT_PALETTE
+
+    def get_body_region_color(self, body_region: str) -> str:
+        """Get color for a specific body region."""
+        return self.body_colors[body_region]
+
+    def get_landmark_color(self, landmark_id: int) -> str:
+        """Get color for a specific landmark based on body region."""
+        if landmark_id <= 10:
+            return self.body_colors['face']
+        elif landmark_id in [11, 13, 15, 17, 19, 21]:
+            return self.body_colors['left_arm']
+        elif landmark_id in [12, 14, 16, 18, 20, 22]:
+            return self.body_colors['right_arm']
+        elif landmark_id in [23, 24]:
+            return self.body_colors['hips']
+        elif landmark_id in [25, 27, 29, 31]:
+            return self.body_colors['left_leg']
+        elif landmark_id in [26, 28, 30, 32]:
+            return self.body_colors['right_leg']
+        else:
+            return '#888888'  # Gray fallback
+
+    def get_landmark_color_bgr(self, landmark_id: int) -> Tuple[int, int, int]:
+        """Get BGR color tuple for landmark based on body region."""
+        return self.hex_to_bgr(self.get_landmark_color(landmark_id))
+
+    def hex_to_bgr(self, hex_color: str) -> Tuple[int, int, int]:
+        """Convert hex color (#RRGGBB) to BGR tuple for OpenCV."""
+        hex_color = hex_color.lstrip('#')
+        r, g, b = tuple(int(hex_color[i:i+2], 16) for i in (0, 2, 4))
+        return (b, g, r)
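
A quick sanity check of the new class (a usage sketch, not part of the commit): `hex_to_bgr` reverses the channel order because OpenCV expects BGR, and landmark 15 (LEFT_WRIST in MediaPipe's pose model) falls in the left-arm group:

```python
from choreopath.colors import Palette

palette = Palette()

# Landmark 15 maps to the left-arm color...
assert palette.get_landmark_color(15) == '#fb923c'

# ...and '#fb923c' is RGB (251, 146, 60), so the BGR tuple is reversed.
assert palette.hex_to_bgr('#fb923c') == (60, 146, 251)
assert palette.get_landmark_color_bgr(15) == (60, 146, 251)
```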

src/choreopath/svg_generator.py (9 additions, 32 deletions)

@@ -3,6 +3,7 @@
 from typing import Dict, Tuple
 from .tracking_data import TrackingData
 import mediapipe as mp
+from .colors import Palette
 
 class SVGGenerator:
     """Generates SVG trajectories from body tracking data."""
@@ -13,14 +14,7 @@ def __init__(self, width: int = 1280, height: int = 720, show_legend: bool = Tru
         self.show_legend = show_legend
 
         # Color scheme for different body regions
-        self.body_colors = {
-            'face': '#f87171',
-            'left_arm': '#fb923c',
-            'right_arm': '#facc15',
-            'hips': '#71717a',
-            'left_leg': '#06b6d4',
-            'right_leg': '#3b82f6'
-        }
+        self.palette = Palette()
 
         # Hierarchical body structure mapping
         self.body_hierarchy = {
@@ -66,23 +60,6 @@ def __init__(self, width: int = 1280, height: int = 720, show_legend: bool = Tru
             32: ["body", "right leg"]
         }
 
-    def get_landmark_color(self, landmark_id: int) -> str:
-        """Get color for a specific landmark based on body region."""
-        if landmark_id <= 10:
-            return self.body_colors['face']
-        elif landmark_id in [11, 13, 15, 17, 19, 21]:
-            return self.body_colors['left_arm']
-        elif landmark_id in [12, 14, 16, 18, 20, 22]:
-            return self.body_colors['right_arm']
-        elif landmark_id in [23, 24]:
-            return self.body_colors['hips']
-        elif landmark_id in [25, 27, 29, 31]:
-            return self.body_colors['left_leg']
-        elif landmark_id in [26, 28, 30, 32]:
-            return self.body_colors['right_leg']
-        else:
-            return '#888888'  # Gray fallback
-
     def normalize_to_svg_coords(self, x: float, y: float) -> Tuple[float, float]:
         """
         Convert normalized coordinates (0-1) to SVG coordinates.
@@ -187,7 +164,7 @@ def generate(self, tracking_data: TrackingData) -> ET.ElementTree:
                 path_data += f" L {x:.2f} {y:.2f}"
 
             path_elem.set('d', path_data)
-            path_elem.set('stroke', self.get_landmark_color(landmark_id))
+            path_elem.set('stroke', self.palette.get_landmark_color(landmark_id))
             path_elem.set('stroke-width', '1')
             path_elem.set('fill', 'none')
             path_elem.set('opacity', '0.7')
@@ -230,12 +207,12 @@ def add_legend(self, svg_root: ET.Element) -> None:
 
         # Legend entries
         legend_items = [
-            ('Face', self.body_colors['face']),
-            ('Left Arm', self.body_colors['left_arm']),
-            ('Right Arm', self.body_colors['right_arm']),
-            ('Hips', self.body_colors['hips']),
-            ('Left Leg', self.body_colors['left_leg']),
-            ('Right Leg', self.body_colors['right_leg'])
+            ('Face', self.palette.get_body_region_color('face')),
+            ('Left Arm', self.palette.get_body_region_color('left_arm')),
+            ('Right Arm', self.palette.get_body_region_color('right_arm')),
+            ('Hips', self.palette.get_body_region_color('hips')),
+            ('Left Leg', self.palette.get_body_region_color('left_leg')),
+            ('Right Leg', self.palette.get_body_region_color('right_leg'))
         ]
 
         for i, (label, color) in enumerate(legend_items):
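
The net effect of this refactor is that the SVG generator and the new video renderer share one source of truth for colors. A minimal check of the delegation (illustrative, not part of the commit):

```python
from choreopath.colors import Palette
from choreopath.svg_generator import SVGGenerator

gen = SVGGenerator()

# The generator no longer owns a color table; it defers to the shared palette,
# so landmark 0 (a face landmark) resolves to the palette's face color.
assert gen.palette.get_landmark_color(0) == Palette().get_body_region_color('face')
```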

src/choreopath/video.py (65 additions, 38 deletions)

@@ -1,62 +1,89 @@
 import cv2
 import mediapipe as mp
-from typing import List, Dict
+import numpy as np
+from typing import List, Dict, Tuple
 
 class Video:
-    def __init__(self, path: str):
+    def __init__(self, path: str, min_detection_confidence: float = 0.5, min_tracking_confidence: float = 0.5):
         """
         Initialize the video object and open the video file for reading.
         """
         self.path = path
         self.cap = cv2.VideoCapture(self.path)
+        self.min_detection_confidence = min_detection_confidence
+        self.min_tracking_confidence = min_tracking_confidence
         if not self.cap.isOpened():
             raise ValueError(f"Error opening video file: {self.path}")
 
-    def track_poses(self, min_detection_confidence: float = 0.5, min_tracking_confidence: float = 0.5) -> List[Dict]:
+        self.pose = mp.solutions.pose.Pose(min_detection_confidence=self.min_detection_confidence, min_tracking_confidence=self.min_tracking_confidence)
+        self.pose_frame_count = 0
+
+    def next_pose(self) -> Tuple[bool, np.ndarray, List[Dict]]:
+        if not self.cap.isOpened():
+            return False, None, []
+
+        ret, frame = self.cap.read()
+        if not ret:
+            return False, None, []
+
+        frame_tracking_data = []
+        timestamp = self.pose_frame_count / self.fps()
+
+        rgb_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
+        results = self.pose.process(rgb_frame)
+
+        if results.pose_landmarks:
+            for idx, landmark in enumerate(results.pose_landmarks.landmark):
+                landmark_name = mp.solutions.pose.PoseLandmark(idx).name
+                frame_tracking_data.append({
+                    'frame': self.pose_frame_count,
+                    'timestamp': timestamp,
+                    'landmark_name': landmark_name,
+                    'landmark_id': idx,
+                    'x': landmark.x,
+                    'y': landmark.y,
+                    'z': landmark.z,
+                    'visibility': landmark.visibility
+                })
+
+        self.pose_frame_count += 1
+
+        return True, frame, frame_tracking_data
+
+    def close(self):
+        self.pose.close()
+        self.cap.release()
+
+    def track_poses(self) -> List[Dict]:
         """
         Track poses in the video and return tracking data.
        """
-        pose = mp.solutions.pose.Pose(min_detection_confidence=min_detection_confidence, min_tracking_confidence=min_tracking_confidence)
-        _frame_count = 0
-        _fps = self.fps()
-        _total_frames = self.total_frames()
-
        tracking_data = []
 
-        while self.cap.isOpened():
-            ret, frame = self.cap.read()
-            if not ret:
+        while True:
+            ok, frame, frame_tracking_data = self.next_pose()
+            if not ok:
                 break
-
-            # Convert BGR to RGB
-            rgb_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
-
-            # Process frame
-            results = pose.process(rgb_frame)
-
-            _timestamp = _frame_count / _fps
-
-            # Extract landmarks
-            if results.pose_landmarks:
-                for idx, landmark in enumerate(results.pose_landmarks.landmark):
-                    landmark_name = mp.solutions.pose.PoseLandmark(idx).name
-                    tracking_data.append({
-                        'frame': _frame_count,
-                        'timestamp': _timestamp,
-                        'landmark_name': landmark_name,
-                        'landmark_id': idx,
-                        'x': landmark.x,
-                        'y': landmark.y,
-                        'z': landmark.z,
-                        'visibility': landmark.visibility
-                    })
-
-            _frame_count += 1
-
-        self.cap.release()
 
+            for frame_data in frame_tracking_data:
+                tracking_data.append(frame_data)
+
         return tracking_data
 
+
+    def width(self) -> int:
+        """
+        Returns the width of the video.
+        """
+        return int(self.cap.get(cv2.CAP_PROP_FRAME_WIDTH))
+
+    def height(self) -> int:
+        """
+        Returns the height of the video.
+        """
+        return int(self.cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
+
     def fps(self) -> float:
         """
         Returns the frames per second of the video.
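
The streaming `next_pose` API is what makes a progressive overlay possible: a consumer can draw on each frame as it arrives instead of buffering the whole video. Below is a minimal sketch of such a consumer, assuming OpenCV's `VideoWriter` with the mp4v codec; the drawing logic is illustrative and not the actual `VideoOverlayRenderer`, whose source is in the commit but not shown in this diff view:

```python
import cv2

from choreopath.colors import Palette
from choreopath.video import Video

video = Video("examples/fall-recovery-4.mp4")
palette = Palette()
writer = cv2.VideoWriter(
    "test-data/overlay-sketch.mp4",
    cv2.VideoWriter_fourcc(*"mp4v"),
    video.fps(),
    (video.width(), video.height()),
)

trails = {}  # landmark_id -> list of (x, y) pixel points seen so far
while True:
    ok, frame, landmarks = video.next_pose()
    if not ok:
        break
    # Landmark coordinates are normalized to [0, 1]; scale them to pixels.
    for lm in landmarks:
        if lm["visibility"] < 0.5:
            continue
        point = (int(lm["x"] * video.width()), int(lm["y"] * video.height()))
        trails.setdefault(lm["landmark_id"], []).append(point)
    # Redraw every accumulated trail on top of the current frame.
    for landmark_id, points in trails.items():
        color = palette.get_landmark_color_bgr(landmark_id)
        for a, b in zip(points, points[1:]):
            cv2.line(frame, a, b, color, 2)
    writer.write(frame)

writer.release()
video.close()
```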
