2626from boxmot .utils .timing import TimingStats , wrap_tracker_reid
2727from typing import Optional , List , Dict , Generator
2828
29- from boxmot .utils .dataloaders .MOT17 import MOT17DetEmbDataset
29+ from boxmot .utils .dataloaders .dataset import MOTDataset
3030from boxmot .postprocessing .gsi import gsi
3131
3232from boxmot .engine .inference import DetectorReIDPipeline , extract_detections , filter_detections
@@ -343,8 +343,13 @@ def generate_dets_embs_batched(args: argparse.Namespace, y: Path, source_root: P
343343 det_fhs = {}
344344 emb_fhs = {r .stem : {} for r in args .reid_model }
345345
346- dets_folder = Path (args .project ) / 'dets_n_embs' / y .stem / 'dets'
347- embs_root = Path (args .project ) / 'dets_n_embs' / y .stem / 'embs'
346+ # runs/dets_n_embs/<dataset_name>/y.stem/... when benchmark is set
347+ benchmark = getattr (args , "benchmark" , None )
348+ dets_base = Path (args .project ) / "dets_n_embs"
349+ if benchmark :
350+ dets_base = dets_base / benchmark
351+ dets_folder = dets_base / y .stem / "dets"
352+ embs_root = dets_base / y .stem / "embs"
348353 total_frames = 0
349354 initial_done = 0
350355
@@ -594,27 +599,27 @@ def generate_dets_embs_batched(args: argparse.Namespace, y: Path, source_root: P
594599 pass
595600
596601
597- def run_generate_dets_embs (opt : argparse .Namespace , timing_stats : Optional [TimingStats ] = None ) -> None :
602+ def run_generate_dets_embs (args : argparse .Namespace , timing_stats : Optional [TimingStats ] = None ) -> None :
598603 """
599604 Generate detections and embeddings for all sequences.
600605
601606 Args:
602- opt : CLI arguments.
607+ args : CLI arguments.
603608 timing_stats: Optional TimingStats for timing instrumentation.
604609 """
605- source_root = Path (opt .source )
610+ source_root = Path (args .source )
606611
607- opt .batch_size = int (getattr (opt , "batch_size" , 16 ))
608- if getattr (opt , "read_threads" , None ) is None :
609- opt .read_threads = min (8 , (os .cpu_count () or 8 ))
610- if not hasattr (opt , "auto_batch" ):
611- opt .auto_batch = True
612- if not hasattr (opt , "resume" ):
613- opt .resume = True
612+ args .batch_size = int (getattr (args , "batch_size" , 16 ))
613+ if getattr (args , "read_threads" , None ) is None :
614+ args .read_threads = min (8 , (os .cpu_count () or 8 ))
615+ if not hasattr (args , "auto_batch" ):
616+ args .auto_batch = True
617+ if not hasattr (args , "resume" ):
618+ args .resume = True
614619
615- for y in opt .yolo_model :
620+ for y in args .yolo_model :
616621 LOGGER .info (f"Generating dets+embs (batched single-process): { y .name } " )
617- generate_dets_embs_batched (opt , y , source_root , timing_stats = timing_stats )
622+ generate_dets_embs_batched (args , y , source_root , timing_stats = timing_stats )
618623
619624def build_dataset_eval_settings (
620625 args : argparse .Namespace ,
@@ -640,18 +645,31 @@ def build_dataset_eval_settings(
640645 eval_classes_cfg = bench_cfg .get ("eval_classes" ) if isinstance (bench_cfg , dict ) else None
641646 distractor_cfg = bench_cfg .get ("distractor_classes" ) if isinstance (bench_cfg , dict ) else None
642647
643- # Classes and ids
644- classes_to_eval = ["person" ]
645- class_ids = [1 ]
646- if isinstance (eval_classes_cfg , dict ) and len (eval_classes_cfg ) > 0 :
647- ordered = sorted (((int (k ), v ) for k , v in eval_classes_cfg .items ()), key = lambda kv : kv [0 ])
648- class_ids = [k for k , _ in ordered ]
649- classes_to_eval = [v for _ , v in ordered ]
650- elif hasattr (args , "classes" ) and args .classes is not None :
648+ classes_to_eval = []
649+ class_ids = []
650+
651+ # Filter classes by user provided classes
652+ if hasattr (args , "classes" ) and args .classes is not None :
651653 class_indices = args .classes if isinstance (args .classes , list ) else [args .classes ]
652654 classes_to_eval = [COCO_CLASSES [int (i )] for i in class_indices ]
653655 class_ids = [int (i ) + 1 for i in class_indices ]
654656
657+ # Match classes by benchmark config
658+ if isinstance (eval_classes_cfg , dict ) and len (eval_classes_cfg ) > 0 :
659+ ordered = sorted (((int (k ), v ) for k , v in eval_classes_cfg .items ()), key = lambda kv : kv [0 ])
660+ if class_ids :
661+ class_ids = [k for k , _ in ordered if class_ids and k in class_ids ]
662+ classes_to_eval = [v for k , v in ordered if class_ids and k in class_ids ]
663+ else :
664+ class_ids = [k for k , _ in ordered ]
665+ classes_to_eval = [v for k , v in ordered ]
666+
667+ # Default classes
668+ if not classes_to_eval :
669+ classes_to_eval = ["person" ]
670+ if not class_ids :
671+ class_ids = [1 ]
672+
655673 # Distractors
656674 distractor_ids : list [int ] = []
657675 if isinstance (distractor_cfg , dict ) and len (distractor_cfg ) > 0 :
@@ -767,6 +785,7 @@ def process_sequence(seq_name: str,
767785 target_fps : Optional [int ],
768786 device : str ,
769787 cfg_dict : Optional [Dict ] = None ,
788+ dataset_name : Optional [str ] = None ,
770789 ):
771790 """
772791 Process a single sequence: run tracker on pre-computed detections/embeddings.
@@ -789,9 +808,13 @@ def process_sequence(seq_name: str,
789808 )
790809
791810 # load with the user’s FPS
792- dataset = MOT17DetEmbDataset (
811+ # runs/dets_n_embs/<dataset_name>/ when dataset_name is set
812+ det_emb_root = Path (project_root ) / "dets_n_embs"
813+ if dataset_name :
814+ det_emb_root = det_emb_root / dataset_name
815+ dataset = MOTDataset (
793816 mot_root = mot_root ,
794- det_emb_root = str (Path ( project_root ) / 'dets_n_embs' ),
817+ det_emb_root = str (det_emb_root ),
795818 model_name = model_name ,
796819 reid_name = reid_name ,
797820 target_fps = target_fps
@@ -845,44 +868,49 @@ def _worker_init():
845868 # each spawned process needs its own sinks
846869 _configure_logging ()
847870
848- def run_generate_mot_results (opt : argparse .Namespace , evolve_config : dict = None , timing_stats : Optional [TimingStats ] = None ) -> None :
871+ def run_generate_mot_results (args : argparse .Namespace , evolve_config : dict = None , timing_stats : Optional [TimingStats ] = None ) -> None :
849872 """
850873 Run tracker on pre-computed detections/embeddings and generate MOT result files.
851874
852875 Args:
853- opt : CLI arguments.
876+ args : CLI arguments.
854877 evolve_config: Optional config dict for hyperparameter tuning.
855878 timing_stats: Optional TimingStats to record tracking/association time.
856879 """
857- # Prepare experiment folder
858- base = opt .project / 'mot' / f"{ opt .yolo_model [0 ].stem } _{ opt .reid_model [0 ].stem } _{ opt .tracking_method } "
880+ # Prepare experiment folder: runs/mot/<dataset_name>/model_reid_tracker when benchmark is set
881+ base = args .project / "mot"
882+ if getattr (args , "benchmark" , None ):
883+ base = base / args .benchmark
884+ base = base / f"{ args .yolo_model [0 ].stem } _{ args .reid_model [0 ].stem } _{ args .tracking_method } "
859885 exp_dir = increment_path (base , sep = "_" , exist_ok = False )
860886 exp_dir .mkdir (parents = True , exist_ok = True )
861- opt .exp_dir = exp_dir
887+ args .exp_dir = exp_dir
862888
863889 # Just collect sequence names by scanning directory names
864890 sequence_names = []
865- for d in Path (opt .source ).iterdir ():
891+ for d in Path (args .source ).iterdir ():
866892 if not d .is_dir ():
867893 continue
868894 img_dir = d / "img1" if (d / "img1" ).exists () else d
869895 if any (img_dir .glob ("*.jpg" )) or any (img_dir .glob ("*.png" )):
870896 sequence_names .append (d .name )
871897 sequence_names .sort ()
872898
873- # Build task arguments
899+ # Build task arguments (include dataset_name for det_emb_root path)
900+ dataset_name = getattr (args , "benchmark" , None )
874901 task_args = [
875902 (
876903 seq ,
877- str (opt .source ),
878- str (opt .project ),
879- opt .yolo_model [0 ].stem ,
880- opt .reid_model [0 ].stem ,
881- opt .tracking_method ,
904+ str (args .source ),
905+ str (args .project ),
906+ args .yolo_model [0 ].stem ,
907+ args .reid_model [0 ].stem ,
908+ args .tracking_method ,
882909 str (exp_dir ),
883- getattr (opt , ' fps' , None ),
884- opt .device ,
910+ getattr (args , " fps" , None ),
911+ args .device ,
885912 evolve_config ,
913+ dataset_name ,
886914 )
887915 for seq in sequence_names
888916 ]
@@ -923,31 +951,31 @@ def run_generate_mot_results(opt: argparse.Namespace, evolve_config: dict = None
923951 )
924952
925953 # Optional GSI postprocessing
926- if getattr (opt , "postprocessing" , "none" ) == "gsi" :
954+ if getattr (args , "postprocessing" , "none" ) == "gsi" :
927955 LOGGER .opt (colors = True ).info ("<cyan>[3b/4]</cyan> Applying GSI postprocessing..." )
928956 from boxmot .postprocessing .gsi import gsi
929957 gsi (mot_results_folder = exp_dir )
930958
931- elif getattr (opt , "postprocessing" , "none" ) == "gbrc" :
959+ elif getattr (args , "postprocessing" , "none" ) == "gbrc" :
932960 LOGGER .opt (colors = True ).info ("<cyan>[3b/4]</cyan> Applying GBRC postprocessing..." )
933961 from boxmot .postprocessing .gbrc import gbrc
934962 gbrc (mot_results_folder = exp_dir )
935963
936964
937- def run_trackeval (opt : argparse .Namespace , verbose : bool = True ) -> dict :
965+ def run_trackeval (args : argparse .Namespace , verbose : bool = True ) -> dict :
938966 """
939967 Runs the trackeval function to evaluate tracking results.
940968
941969 Args:
942- opt (Namespace): Parsed command line arguments.
970+ args (Namespace): Parsed command line arguments.
943971 verbose (bool): Whether to print results summary. Default True.
944972 """
945- seq_paths , seq_info = _collect_seq_info (opt .source )
946- annotations_dir = opt .source .parent / "annotations"
947- gt_folder = annotations_dir if annotations_dir .exists () else opt .source
973+ seq_paths , seq_info = _collect_seq_info (args .source )
974+ annotations_dir = args .source .parent / "annotations"
975+ gt_folder = annotations_dir if annotations_dir .exists () else args .source
948976
949977 if not seq_paths :
950- raise ValueError (f"No sequences with images found under { opt .source } " )
978+ raise ValueError (f"No sequences with images found under { args .source } " )
951979
952980 if annotations_dir .exists ():
953981 for seq_name in list (seq_info .keys ()):
@@ -967,22 +995,26 @@ def run_trackeval(opt: argparse.Namespace, verbose: bool = True) -> dict:
967995 seq_info [seq_name ] = max (seq_info .get (seq_name , 0 ) or 0 , max_frame )
968996 except Exception :
969997 LOGGER .warning (f"Failed to read annotation file { ann_file } for sequence length inference" )
970- save_dir = Path (opt .project ) / opt .name
971-
972- trackeval_results = trackeval (opt , seq_paths , save_dir , gt_folder , seq_info = seq_info )
998+ # runs/<dataset_name>/<name> when benchmark is set
999+ if getattr (args , "benchmark" , None ):
1000+ save_dir = Path (args .project ) / args .benchmark / args .name
1001+ else :
1002+ save_dir = Path (args .project ) / args .name
1003+
1004+ trackeval_results = trackeval (args , seq_paths , save_dir , gt_folder , seq_info = seq_info )
9731005 parsed_results = parse_mot_results (trackeval_results )
9741006
9751007 # Load config to filter classes
9761008 # Try to load config from benchmark name first, then fallback to source parent name
977- cfg_name = getattr (opt , 'benchmark' , str (opt .source .parent .name ))
1009+ cfg_name = getattr (args , 'benchmark' , str (args .source .parent .name ))
9781010 try :
9791011 cfg = load_dataset_cfg (cfg_name )
9801012 except FileNotFoundError :
9811013 # If config not found, try to find it by checking if source path ends with a known config name
9821014 # This handles cases where source is a custom path
9831015 found = False
9841016 for config_file in DATASET_CONFIGS .glob ("*.yaml" ):
985- if config_file .stem in str (opt .source ):
1017+ if config_file .stem in str (args .source ):
9861018 cfg = load_dataset_cfg (config_file .stem )
9871019 found = True
9881020 break
@@ -991,7 +1023,7 @@ def run_trackeval(opt: argparse.Namespace, verbose: bool = True) -> dict:
9911023 LOGGER .warning (f"Could not find dataset config for { cfg_name } . Class filtering might be incorrect." )
9921024 cfg = {}
9931025
994- # Filter parsed_results based on user provided classes (opt .classes)
1026+ # Filter parsed_results based on user provided classes (args .classes)
9951027 single_class_mode = False
9961028
9971029 # Priority 1: Benchmark config classes (overrides user classes)
@@ -1010,8 +1042,8 @@ def run_trackeval(opt: argparse.Namespace, verbose: bool = True) -> dict:
10101042 if len (bench_classes ) == 1 :
10111043 single_class_mode = True
10121044 # Priority 2: User provided classes
1013- elif hasattr (opt , 'classes' ) and opt .classes is not None :
1014- class_indices = opt .classes if isinstance (opt .classes , list ) else [opt .classes ]
1045+ elif hasattr (args , 'classes' ) and args .classes is not None :
1046+ class_indices = args .classes if isinstance (args .classes , list ) else [args .classes ]
10151047 user_classes = [COCO_CLASSES [int (i )] for i in class_indices ]
10161048 parsed_results = {k : v for k , v in parsed_results .items () if k in user_classes }
10171049 if len (user_classes ) == 1 :
@@ -1069,8 +1101,8 @@ def run_trackeval(opt: argparse.Namespace, verbose: bool = True) -> dict:
10691101
10701102 LOGGER .opt (colors = True ).info ("<blue>" + "=" * 105 + "</blue>" )
10711103
1072- if opt .ci :
1073- with open (opt .tracking_method + "_output.json" , "w" ) as outfile :
1104+ if args .ci :
1105+ with open (args .tracking_method + "_output.json" , "w" ) as outfile :
10741106 outfile .write (json .dumps (final_results ))
10751107
10761108 return final_results
0 commit comments