@@ -463,13 +463,14 @@ def process_frame(self, frame, frame_time, tracker, frame_number):
463463
464464 return fg_mask , frame_detections
465465
466- def classify_confirmed_tracks (self , video_path , confirmed_track_ids ):
466+ def classify_confirmed_tracks (self , video_path , confirmed_track_ids , crops_dir = None ):
467467 """
468468 Classify only the confirmed tracks by re-reading relevant frames.
469469
470470 Args:
471471 video_path: Path to original video
472472 confirmed_track_ids: Set of track IDs that passed topology analysis
473+ crops_dir: Optional directory to save cropped frames
473474
474475 Returns:
475476 dict: track_id -> list of classifications
@@ -480,6 +481,15 @@ def classify_confirmed_tracks(self, video_path, confirmed_track_ids):
480481
481482 print (f"\n Classifying { len (confirmed_track_ids )} confirmed tracks..." )
482483
484+ # Setup crops directory if requested
485+ if crops_dir :
486+ os .makedirs (crops_dir , exist_ok = True )
487+ # Create subdirectory for each track
488+ for track_id in confirmed_track_ids :
489+ track_dir = os .path .join (crops_dir , str (track_id )[:8 ])
490+ os .makedirs (track_dir , exist_ok = True )
491+ print (f" Saving crops to: { crops_dir } " )
492+
483493 # Group detections by frame for confirmed tracks
484494 frames_to_classify = defaultdict (list )
485495 for det in self .all_detections :
@@ -518,11 +528,22 @@ def classify_confirmed_tracks(self, video_path, confirmed_track_ids):
518528 track_classifications [det ['track_id' ]].append (classification )
519529 classified_count += 1
520530
531+ # Save crop if requested
532+ if crops_dir :
533+ track_id = det ['track_id' ]
534+ track_dir = os .path .join (crops_dir , str (track_id )[:8 ])
535+ crop = frame [int (y1 ):int (y2 ), int (x1 ):int (x2 )]
536+ if crop .size > 0 :
537+ crop_path = os .path .join (track_dir , f"frame_{ target_frame :06d} .jpg" )
538+ cv2 .imwrite (crop_path , crop )
539+
521540 if classified_count % 20 == 0 :
522541 print (f" Classified { classified_count } detections..." , end = '\r ' )
523542
524543 cap .release ()
525544 print (f"\n ✓ Classified { classified_count } detections from { len (confirmed_track_ids )} tracks" )
545+ if crops_dir :
546+ print (f"✓ Saved { classified_count } crops to { crops_dir } " )
526547
527548 return track_classifications
528549
@@ -736,7 +757,7 @@ def save_results(self, results, output_paths):
736757# VIDEO PROCESSING
737758# ============================================================================
738759
739- def process_video (video_path , processor , output_paths , show_video = False , fps = None ):
760+ def process_video (video_path , processor , output_paths , show_video = False , fps = None , crops_dir = None ):
740761 """
741762 Process video file with efficient classification (confirmed tracks only).
742763
@@ -752,6 +773,7 @@ def process_video(video_path, processor, output_paths, show_video=False, fps=Non
752773 output_paths: Dict with output file paths
753774 show_video: Display video while processing
754775 fps: Target FPS (skip frames if lower than input)
776+ crops_dir: Optional directory to save cropped frames for each track
755777
756778 Returns:
757779 list: Aggregated results
@@ -832,7 +854,7 @@ def process_video(video_path, processor, output_paths, show_video=False, fps=Non
832854 print ("=" * 60 )
833855
834856 if confirmed_track_ids :
835- processor .classify_confirmed_tracks (video_path , confirmed_track_ids )
857+ processor .classify_confirmed_tracks (video_path , confirmed_track_ids , crops_dir = crops_dir )
836858 results = processor .hierarchical_aggregation (confirmed_track_ids )
837859 else :
838860 results = []
@@ -1050,6 +1072,7 @@ def inference(
10501072 fps = None ,
10511073 config = None ,
10521074 backbone = "resnet50" ,
1075+ crops = False ,
10531076):
10541077 """
10551078 Run inference on a video file.
@@ -1066,6 +1089,7 @@ def inference(
10661089 - dict: config parameters directly
10671090 backbone: ResNet backbone ('resnet18', 'resnet50', 'resnet101').
10681091 If model checkpoint contains backbone info, it will be used instead.
1092+ crops: If True, save cropped frames for each classified track
10691093
10701094 Returns:
10711095 dict: Processing results with output file paths
@@ -1075,6 +1099,7 @@ def inference(
10751099 - {video_name}_debug.mp4: Side-by-side with GMM motion mask
10761100 - {video_name}_results.csv: Aggregated track results
10771101 - {video_name}_detections.csv: Frame-by-frame detections
1102+ - {video_name}_crops/ (if crops=True): Directory with cropped frames per track
10781103 """
10791104 if not os .path .exists (video_path ):
10801105 print (f"Error: Video not found: { video_path } " )
@@ -1106,6 +1131,11 @@ def inference(
11061131 "detections_csv" : os .path .join (output_dir , f"{ video_name } _detections.csv" ),
11071132 }
11081133
1134+ # Setup crops directory if requested
1135+ crops_dir = os .path .join (output_dir , f"{ video_name } _crops" ) if crops else None
1136+ if crops_dir :
1137+ output_paths ["crops_dir" ] = crops_dir
1138+
11091139 print ("\n " + "=" * 60 )
11101140 print ("BPLUSPLUS INFERENCE" )
11111141 print ("=" * 60 )
@@ -1133,7 +1163,8 @@ def inference(
11331163 video_path = video_path ,
11341164 processor = processor ,
11351165 output_paths = output_paths ,
1136- fps = fps
1166+ fps = fps ,
1167+ crops_dir = crops_dir
11371168 )
11381169
11391170 return {
@@ -1174,6 +1205,7 @@ def main():
11741205 - {video_name}_debug.mp4: Side-by-side view with GMM motion mask
11751206 - {video_name}_results.csv: Aggregated track results
11761207 - {video_name}_detections.csv: Frame-by-frame detections
1208+ - {video_name}_crops/ (with --crops): Cropped frames for each track
11771209 """
11781210 )
11791211
@@ -1192,6 +1224,8 @@ def main():
11921224 parser .add_argument ('--backbone' , '-b' , default = 'resnet50' ,
11931225 choices = ['resnet18' , 'resnet50' , 'resnet101' ],
11941226 help = 'ResNet backbone (default: resnet50, overridden by checkpoint if saved)' )
1227+ parser .add_argument ('--crops' , action = 'store_true' ,
1228+ help = 'Save cropped frames for each classified track' )
11951229
11961230 # Detection parameters (override config)
11971231 defaults = DEFAULT_DETECTION_CONFIG
@@ -1267,6 +1301,7 @@ def main():
12671301 fps = args .fps ,
12681302 config = config ,
12691303 backbone = args .backbone ,
1304+ crops = args .crops ,
12701305 )
12711306
12721307 if result .get ("success" ):
0 commit comments