Skip to content

Commit 2643c9e

Browse files
committed
v2.0.2: Add --crops flag to save cropped frames per track
1 parent e234ea6 commit 2643c9e

File tree

3 files changed

+50
-5
lines changed

3 files changed

+50
-5
lines changed

CHANGELOG.md

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,16 @@ All notable changes to this project will be documented in this file.
44
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
55
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
66

7+
## [2.0.2] - 2025-01-20
8+
9+
### Added
10+
- **Crop export**: New `--crops` flag in inference to save a cropped image of each classified detection to `{video_name}_crops/`, organized into one subdirectory per track ID
11+
12+
## [2.0.1] - 2025-01-20
13+
14+
### Fixed
15+
- Minor bug fixes and code cleanup
16+
717
## [2.0.0] - 2025-01-06
818

919
### Added

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[tool.poetry]
22
name = "bplusplus"
3-
version = "2.0.0"
3+
version = "2.0.2"
44
description = "A simple method to create AI models for biodiversity, with collect and prepare pipeline"
55
authors = ["Titus Venverloo <tvenver@mit.edu>", "Deniz Aydemir <deniz@aydemir.us>", "Orlando Closs <orlandocloss@pm.me>", "Ase Hatveit <aase@mit.edu>"]
66
license = "MIT"

src/bplusplus/inference.py

Lines changed: 39 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -463,13 +463,14 @@ def process_frame(self, frame, frame_time, tracker, frame_number):
463463

464464
return fg_mask, frame_detections
465465

466-
def classify_confirmed_tracks(self, video_path, confirmed_track_ids):
466+
def classify_confirmed_tracks(self, video_path, confirmed_track_ids, crops_dir=None):
467467
"""
468468
Classify only the confirmed tracks by re-reading relevant frames.
469469
470470
Args:
471471
video_path: Path to original video
472472
confirmed_track_ids: Set of track IDs that passed topology analysis
473+
crops_dir: Optional directory to save cropped frames
473474
474475
Returns:
475476
dict: track_id -> list of classifications
@@ -480,6 +481,15 @@ def classify_confirmed_tracks(self, video_path, confirmed_track_ids):
480481

481482
print(f"\nClassifying {len(confirmed_track_ids)} confirmed tracks...")
482483

484+
# Setup crops directory if requested
485+
if crops_dir:
486+
os.makedirs(crops_dir, exist_ok=True)
487+
# Create subdirectory for each track
488+
for track_id in confirmed_track_ids:
489+
track_dir = os.path.join(crops_dir, str(track_id)[:8])
490+
os.makedirs(track_dir, exist_ok=True)
491+
print(f" Saving crops to: {crops_dir}")
492+
483493
# Group detections by frame for confirmed tracks
484494
frames_to_classify = defaultdict(list)
485495
for det in self.all_detections:
@@ -518,11 +528,22 @@ def classify_confirmed_tracks(self, video_path, confirmed_track_ids):
518528
track_classifications[det['track_id']].append(classification)
519529
classified_count += 1
520530

531+
# Save crop if requested
532+
if crops_dir:
533+
track_id = det['track_id']
534+
track_dir = os.path.join(crops_dir, str(track_id)[:8])
535+
crop = frame[int(y1):int(y2), int(x1):int(x2)]
536+
if crop.size > 0:
537+
crop_path = os.path.join(track_dir, f"frame_{target_frame:06d}.jpg")
538+
cv2.imwrite(crop_path, crop)
539+
521540
if classified_count % 20 == 0:
522541
print(f" Classified {classified_count} detections...", end='\r')
523542

524543
cap.release()
525544
print(f"\n✓ Classified {classified_count} detections from {len(confirmed_track_ids)} tracks")
545+
if crops_dir:
546+
print(f"✓ Saved {classified_count} crops to {crops_dir}")
526547

527548
return track_classifications
528549

@@ -736,7 +757,7 @@ def save_results(self, results, output_paths):
736757
# VIDEO PROCESSING
737758
# ============================================================================
738759

739-
def process_video(video_path, processor, output_paths, show_video=False, fps=None):
760+
def process_video(video_path, processor, output_paths, show_video=False, fps=None, crops_dir=None):
740761
"""
741762
Process video file with efficient classification (confirmed tracks only).
742763
@@ -752,6 +773,7 @@ def process_video(video_path, processor, output_paths, show_video=False, fps=Non
752773
output_paths: Dict with output file paths
753774
show_video: Display video while processing
754775
fps: Target FPS (skip frames if lower than input)
776+
crops_dir: Optional directory to save cropped frames for each track
755777
756778
Returns:
757779
list: Aggregated results
@@ -832,7 +854,7 @@ def process_video(video_path, processor, output_paths, show_video=False, fps=Non
832854
print("="*60)
833855

834856
if confirmed_track_ids:
835-
processor.classify_confirmed_tracks(video_path, confirmed_track_ids)
857+
processor.classify_confirmed_tracks(video_path, confirmed_track_ids, crops_dir=crops_dir)
836858
results = processor.hierarchical_aggregation(confirmed_track_ids)
837859
else:
838860
results = []
@@ -1050,6 +1072,7 @@ def inference(
10501072
fps=None,
10511073
config=None,
10521074
backbone="resnet50",
1075+
crops=False,
10531076
):
10541077
"""
10551078
Run inference on a video file.
@@ -1066,6 +1089,7 @@ def inference(
10661089
- dict: config parameters directly
10671090
backbone: ResNet backbone ('resnet18', 'resnet50', 'resnet101').
10681091
If model checkpoint contains backbone info, it will be used instead.
1092+
crops: If True, save cropped frames for each classified track
10691093
10701094
Returns:
10711095
dict: Processing results with output file paths
@@ -1075,6 +1099,7 @@ def inference(
10751099
- {video_name}_debug.mp4: Side-by-side with GMM motion mask
10761100
- {video_name}_results.csv: Aggregated track results
10771101
- {video_name}_detections.csv: Frame-by-frame detections
1102+
- {video_name}_crops/ (if crops=True): Directory with cropped frames per track
10781103
"""
10791104
if not os.path.exists(video_path):
10801105
print(f"Error: Video not found: {video_path}")
@@ -1106,6 +1131,11 @@ def inference(
11061131
"detections_csv": os.path.join(output_dir, f"{video_name}_detections.csv"),
11071132
}
11081133

1134+
# Setup crops directory if requested
1135+
crops_dir = os.path.join(output_dir, f"{video_name}_crops") if crops else None
1136+
if crops_dir:
1137+
output_paths["crops_dir"] = crops_dir
1138+
11091139
print("\n" + "="*60)
11101140
print("BPLUSPLUS INFERENCE")
11111141
print("="*60)
@@ -1133,7 +1163,8 @@ def inference(
11331163
video_path=video_path,
11341164
processor=processor,
11351165
output_paths=output_paths,
1136-
fps=fps
1166+
fps=fps,
1167+
crops_dir=crops_dir
11371168
)
11381169

11391170
return {
@@ -1174,6 +1205,7 @@ def main():
11741205
- {video_name}_debug.mp4: Side-by-side view with GMM motion mask
11751206
- {video_name}_results.csv: Aggregated track results
11761207
- {video_name}_detections.csv: Frame-by-frame detections
1208+
- {video_name}_crops/ (with --crops): Cropped frames for each track
11771209
"""
11781210
)
11791211

@@ -1192,6 +1224,8 @@ def main():
11921224
parser.add_argument('--backbone', '-b', default='resnet50',
11931225
choices=['resnet18', 'resnet50', 'resnet101'],
11941226
help='ResNet backbone (default: resnet50, overridden by checkpoint if saved)')
1227+
parser.add_argument('--crops', action='store_true',
1228+
help='Save cropped frames for each classified track')
11951229

11961230
# Detection parameters (override config)
11971231
defaults = DEFAULT_DETECTION_CONFIG
@@ -1267,6 +1301,7 @@ def main():
12671301
fps=args.fps,
12681302
config=config,
12691303
backbone=args.backbone,
1304+
crops=args.crops,
12701305
)
12711306

12721307
if result.get("success"):

0 commit comments

Comments
 (0)