@@ -414,7 +414,7 @@ def __init__(
414414 # Classification model (only if classify=True)
415415 if classify :
416416 self .device = torch .device ("cuda" if torch .cuda .is_available () else "cpu" )
417- print (f"Using device: { self .device } " )
417+ print (f"Using device: { self .device } " )
418418
419419 if hierarchical_model_path is None :
420420 raise ValueError ("hierarchical_model_path is required when classify=True" )
@@ -432,30 +432,30 @@ def __init__(
432432 raise ValueError ("species_list not found in checkpoint and not provided as argument" )
433433
434434 self .species_list = species_list
435-
436- # Build taxonomy
437- self .taxonomy , self .species_to_genus , self .genus_to_family = get_taxonomy (species_list )
438- self .level_to_idx , self .idx_to_level = create_mappings (self .taxonomy , species_list )
439- self .family_list = sorted (self .taxonomy [1 ])
440- self .genus_list = sorted (self .taxonomy [2 ].keys ())
441-
442- model_backbone = checkpoint .get ("backbone" , backbone )
443- if model_backbone != backbone :
444- print (f"Note: Using backbone '{ model_backbone } ' from checkpoint (overrides '{ backbone } ')" )
445-
446- num_classes = [len (self .family_list ), len (self .genus_list ), len (self .species_list )]
447- print (f"Model architecture: { num_classes } classes per level, backbone: { model_backbone } " )
448-
449- self .model = HierarchicalInsectClassifier (num_classes , backbone = model_backbone )
450- self .model .load_state_dict (state_dict , strict = False )
451- self .model .to (self .device )
452- self .model .eval ()
453-
454- self .transform = transforms .Compose ([
435+
436+ # Build taxonomy
437+ self .taxonomy , self .species_to_genus , self .genus_to_family = get_taxonomy (species_list )
438+ self .level_to_idx , self .idx_to_level = create_mappings (self .taxonomy , species_list )
439+ self .family_list = sorted (self .taxonomy [1 ])
440+ self .genus_list = sorted (self .taxonomy [2 ].keys ())
441+
442+ model_backbone = checkpoint .get ("backbone" , backbone )
443+ if model_backbone != backbone :
444+ print (f"Note: Using backbone '{ model_backbone } ' from checkpoint (overrides '{ backbone } ')" )
445+
446+ num_classes = [len (self .family_list ), len (self .genus_list ), len (self .species_list )]
447+ print (f"Model architecture: { num_classes } classes per level, backbone: { model_backbone } " )
448+
449+ self .model = HierarchicalInsectClassifier (num_classes , backbone = model_backbone )
450+ self .model .load_state_dict (state_dict , strict = False )
451+ self .model .to (self .device )
452+ self .model .eval ()
453+
454+ self .transform = transforms .Compose ([
455455 transforms .Resize ((self .img_size , self .img_size )),
456- transforms .ToTensor (),
456+ transforms .ToTensor (),
457457 transforms .Normalize (mean = IMAGENET_MEAN , std = IMAGENET_STD )
458- ])
458+ ])
459459 else :
460460 print ("Detection-only mode (no classification)" )
461461 self .species_list = []
@@ -982,15 +982,15 @@ def process_video(video_path: str, processor: VideoInferenceProcessor,
982982 # PHASE 3: Classification (or detection-only)
983983 # =========================================================================
984984 if processor .classify :
985- print ("\n " + "=" * 60 )
986- print ("PHASE 3: CLASSIFICATION (Confirmed Tracks Only)" )
987- print ("=" * 60 )
988-
989- if confirmed_track_ids :
990- processor .classify_confirmed_tracks (video_path , confirmed_track_ids , crops_dir = crops_dir )
991- results = processor .hierarchical_aggregation (confirmed_track_ids )
992- else :
993- results = []
985+ print ("\n " + "=" * 60 )
986+ print ("PHASE 3: CLASSIFICATION (Confirmed Tracks Only)" )
987+ print ("=" * 60 )
988+
989+ if confirmed_track_ids :
990+ processor .classify_confirmed_tracks (video_path , confirmed_track_ids , crops_dir = crops_dir )
991+ results = processor .hierarchical_aggregation (confirmed_track_ids )
992+ else :
993+ results = []
994994 else :
995995 if confirmed_track_ids :
996996 results = processor .detection_only_results (confirmed_track_ids )
@@ -1004,24 +1004,24 @@ def process_video(video_path: str, processor: VideoInferenceProcessor,
10041004 has_composites = "track_composites_dir" in output_paths
10051005
10061006 if has_video or has_composites :
1007- print ("\n " + "=" * 60 )
1007+ print ("\n " + "=" * 60 )
10081008 print ("PHASE 4: RENDERING OUTPUT" )
1009- print ("=" * 60 )
1010-
1009+ print ("=" * 60 )
1010+
10111011 if "debug_video" in output_paths :
1012- print (f"\n Rendering debug video (all detections)..." )
1013- _render_debug_video (
1014- video_path , output_paths ["debug_video" ],
1015- processor , confirmed_track_ids , all_track_info , input_fps
1016- )
1017-
1012+ print (f"\n Rendering debug video (all detections)..." )
1013+ _render_debug_video (
1014+ video_path , output_paths ["debug_video" ],
1015+ processor , confirmed_track_ids , all_track_info , input_fps
1016+ )
1017+
10181018 if "annotated_video" in output_paths :
1019- print (f"\n Rendering annotated video ({ len (confirmed_track_ids )} confirmed tracks)..." )
1020- _render_annotated_video (
1021- video_path , output_paths ["annotated_video" ],
1022- processor , confirmed_track_ids , input_fps
1023- )
1024-
1019+ print (f"\n Rendering annotated video ({ len (confirmed_track_ids )} confirmed tracks)..." )
1020+ _render_annotated_video (
1021+ video_path , output_paths ["annotated_video" ],
1022+ processor , confirmed_track_ids , input_fps
1023+ )
1024+
10251025 if has_composites :
10261026 print (f"\n Rendering track composite images..." )
10271027 _render_track_composites (
@@ -1107,12 +1107,12 @@ def _render_annotated_video(video_path: str, output_path: str,
11071107 processor : VideoInferenceProcessor ,
11081108 confirmed_track_ids : Set [str ], fps : float ) -> None :
11091109 """Render annotated video showing only confirmed tracks with classifications."""
1110- cap = cv2 .VideoCapture (video_path )
1111- width = int (cap .get (cv2 .CAP_PROP_FRAME_WIDTH ))
1112- height = int (cap .get (cv2 .CAP_PROP_FRAME_HEIGHT ))
1113-
1114- out = cv2 .VideoWriter (output_path , cv2 .VideoWriter_fourcc (* 'mp4v' ), fps , (width , height ))
1115-
1110+ cap = cv2 .VideoCapture (video_path )
1111+ width = int (cap .get (cv2 .CAP_PROP_FRAME_WIDTH ))
1112+ height = int (cap .get (cv2 .CAP_PROP_FRAME_HEIGHT ))
1113+
1114+ out = cv2 .VideoWriter (output_path , cv2 .VideoWriter_fourcc (* 'mp4v' ), fps , (width , height ))
1115+
11161116 if not confirmed_track_ids :
11171117 frame_num = 0
11181118 while True :
@@ -1301,7 +1301,7 @@ def inference(
13011301 print (f"Video: { video_path } " )
13021302 print (f"Mode: { 'Detection + Classification' if classify else 'Detection only' } " )
13031303 if classify :
1304- print (f"Model: { hierarchical_model_path } " )
1304+ print (f"Model: { hierarchical_model_path } " )
13051305 print (f"Output directory: { output_dir } " )
13061306 print ("\n Output files:" )
13071307 for name , path in output_paths .items ():
0 commit comments