@@ -49,16 +49,17 @@ def __init__(
4949 def track_dataset (self ):
5050 """Run tracking on complete dataset."""
5151 self .callback ("on_dataset_track_start" )
52+
5253 self .callback (
5354 "on_video_loop_start" ,
54- video_metadata = pd .Series (name = self .video_filename ),
55+ video_metadata = pd .Series ({ " name" : self .video_filename } ),
5556 video_idx = 0 ,
5657 index = 0 ,
5758 )
5859 detections = self .video_loop ()
5960 self .callback (
6061 "on_video_loop_end" ,
61- video_metadata = pd .Series (name = self .video_filename ),
62+ video_metadata = pd .Series ({ " name" : self .video_filename } ),
6263 video_idx = 0 ,
6364 detections = detections ,
6465 )
@@ -81,6 +82,15 @@ def video_loop(self):
8182 # print('in offline.py, model_names: ', model_names)
8283 frame_idx = - 1
8384 detections = pd .DataFrame ()
85+
86+ # Initialize module callbacks at the start
87+ for model_name in model_names :
88+ dummy_dataloader = []
89+ self .callback ("on_module_start" , task = model_name , dataloader = dummy_dataloader )
90+
91+ # Initialize image metadata for the current frame
92+ image_metadata = pd .DataFrame ()
93+
8494 while video_cap .isOpened ():
8595 frame_idx += 1
8696 ret , frame = video_cap .read ()
@@ -89,10 +99,20 @@ def video_loop(self):
8999 image = cv2 .cvtColor (frame , cv2 .COLOR_BGR2RGB )
90100 if not ret :
91101 break
92- metadata = pd .Series ({"id" : frame_idx , "frame" : frame_idx ,
93- "video_id" : video_filename }, name = frame_idx )
102+
103+ # Create base metadata for this frame
104+ base_metadata = pd .Series ({
105+ "id" : frame_idx ,
106+ "frame" : frame_idx ,
107+ "video_id" : video_filename
108+ }, name = frame_idx )
109+
110+ # Reset image metadata for this frame with base metadata
111+ image_metadata = pd .DataFrame ([base_metadata ])
112+
94113 self .callback ("on_image_loop_start" ,
95- image_metadata = metadata , image_idx = frame_idx , index = frame_idx )
114+ image_metadata = base_metadata , image_idx = frame_idx , index = frame_idx )
115+
96116 for model_name in model_names :
97117 model = self .models [model_name ]
98118 if len (detections ) > 0 :
@@ -102,49 +122,59 @@ def video_loop(self):
102122 if model .level == "video" :
103123 raise "Video-level not supported for online video tracking"
104124 elif model .level == "image" :
105- batch = model .preprocess (image = image , detections = dets , metadata = metadata )
125+ batch = model .preprocess (image = image , detections = dets , metadata = image_metadata . iloc [ 0 ] )
106126 batch = type (model ).collate_fn ([(frame_idx , batch )])
107- detections = self .default_step (batch , model_name , detections , metadata )
127+ detections , image_metadata = self .default_step (batch , model_name , detections , image_metadata )
108128 elif model .level == "detection" :
109129 for idx , detection in dets .iterrows ():
110- batch = model .preprocess (image = image , detection = detection , metadata = metadata )
130+ batch = model .preprocess (image = image , detection = detection , metadata = image_metadata . iloc [ 0 ] )
111131 batch = type (model ).collate_fn ([(detection .name , batch )])
112- detections = self .default_step (batch , model_name , detections , metadata )
132+ detections , image_metadata = self .default_step (batch , model_name , detections , image_metadata )
113133 self .callback ("on_image_loop_end" ,
114- image_metadata = metadata , image = image ,
134+ image_metadata = image_metadata . iloc [ 0 ] , image = image ,
115135 image_idx = frame_idx , detections = detections )
116136
137+ # Finalize module callbacks at the end
138+ for model_name in model_names :
139+ self .callback ("on_module_end" , task = model_name , detections = detections )
140+
117141 return detections
118142
def default_step(self, batch: Any, task: str, detections: pd.DataFrame, image_metadata: pd.DataFrame, **kwargs):
    """Run one step of module ``task`` on ``batch`` and merge its outputs.

    Args:
        batch: A ``(idxs, data)`` pair as produced by the module's
            ``collate_fn``. ``idxs`` is moved to CPU when it is a
            ``torch.Tensor``.
        task: Key into ``self.models`` selecting the module to run.
        detections: Detections accumulated so far; the module's output is
            merged into it with ``merge_dataframes``.
        image_metadata: Metadata frame for the image(s) being processed.
            It is handed to the module's ``process`` call and, when the
            module returns metadata, updated with it.
        **kwargs: Forwarded to ``process`` for non-image-level modules only.

    Returns:
        A ``(detections, image_metadata)`` tuple holding the merged
        detections and the (possibly updated) image metadata.
    """
    model = self.models[task]
    self.callback("on_module_step_start", task=task, batch=batch)

    # NOTE: `batch` is rebound to the inner payload here, so the
    # "on_module_step_end" callback below receives the payload, not the
    # (idxs, payload) pair that "on_module_step_start" was given.
    idxs, batch = batch
    if isinstance(idxs, torch.Tensor):
        idxs = idxs.cpu()

    if model.level == "image":
        log.info("step : %s", idxs)
        if len(detections) > 0:
            # Keep only the detections that belong to the image(s)
            # described by `image_metadata`.
            in_current_images = np.isin(detections.image_id, image_metadata.index)
            batch_input_detections = detections.loc[in_current_images]
        else:
            batch_input_detections = detections
        step_output = model.process(batch, batch_input_detections, image_metadata)
    else:
        step_output = model.process(
            batch=batch,
            detections=detections.loc[idxs],
            metadatas=image_metadata,
            **kwargs,
        )

    # Some modules return (detections, metadatas) instead of detections only;
    # in that case fold the returned metadata into the image metadata.
    if isinstance(step_output, tuple):
        step_output, step_metadatas = step_output
        image_metadata = merge_dataframes(image_metadata, step_metadatas)

    detections = merge_dataframes(detections, step_output)
    self.callback(
        "on_module_step_end", task=task, batch=batch, detections=detections
    )
    return detections, image_metadata
150180
0 commit comments