
Commit 99b8918

Support Online Tracking
1 parent 90cae75 commit 99b8918

2 files changed: +140 −41 lines

tracklab/engine/video.py

Lines changed: 46 additions & 16 deletions
@@ -49,16 +49,17 @@ def __init__(
     def track_dataset(self):
         """Run tracking on complete dataset."""
         self.callback("on_dataset_track_start")
+
         self.callback(
             "on_video_loop_start",
-            video_metadata=pd.Series(name=self.video_filename),
+            video_metadata=pd.Series({"name": self.video_filename}),
             video_idx=0,
             index=0,
         )
         detections = self.video_loop()
         self.callback(
             "on_video_loop_end",
-            video_metadata=pd.Series(name=self.video_filename),
+            video_metadata=pd.Series({"name": self.video_filename}),
             video_idx=0,
             detections=detections,
         )
@@ -81,6 +82,15 @@ def video_loop(self):
         # print('in offline.py, model_names: ', model_names)
         frame_idx = -1
         detections = pd.DataFrame()
+
+        # Initialize module callbacks at the start
+        for model_name in model_names:
+            dummy_dataloader = []
+            self.callback("on_module_start", task=model_name, dataloader=dummy_dataloader)
+
+        # Initialize image metadata for the current frame
+        image_metadata = pd.DataFrame()
+
         while video_cap.isOpened():
             frame_idx += 1
             ret, frame = video_cap.read()
@@ -89,10 +99,20 @@
             image = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
             if not ret:
                 break
-            metadata = pd.Series({"id": frame_idx, "frame": frame_idx,
-                                  "video_id": video_filename}, name=frame_idx)
+
+            # Create base metadata for this frame
+            base_metadata = pd.Series({
+                "id": frame_idx,
+                "frame": frame_idx,
+                "video_id": video_filename
+            }, name=frame_idx)
+
+            # Reset image metadata for this frame with base metadata
+            image_metadata = pd.DataFrame([base_metadata])
+
             self.callback("on_image_loop_start",
-                          image_metadata=metadata, image_idx=frame_idx, index=frame_idx)
+                          image_metadata=base_metadata, image_idx=frame_idx, index=frame_idx)
+
             for model_name in model_names:
                 model = self.models[model_name]
                 if len(detections) > 0:
@@ -102,49 +122,59 @@ def video_loop(self):
                 if model.level == "video":
                     raise "Video-level not supported for online video tracking"
                 elif model.level == "image":
-                    batch = model.preprocess(image=image, detections=dets, metadata=metadata)
+                    batch = model.preprocess(image=image, detections=dets, metadata=image_metadata.iloc[0])
                     batch = type(model).collate_fn([(frame_idx, batch)])
-                    detections = self.default_step(batch, model_name, detections, metadata)
+                    detections, image_metadata = self.default_step(batch, model_name, detections, image_metadata)
                 elif model.level == "detection":
                     for idx, detection in dets.iterrows():
-                        batch = model.preprocess(image=image, detection=detection, metadata=metadata)
+                        batch = model.preprocess(image=image, detection=detection, metadata=image_metadata.iloc[0])
                         batch = type(model).collate_fn([(detection.name, batch)])
-                        detections = self.default_step(batch, model_name, detections, metadata)
+                        detections, image_metadata = self.default_step(batch, model_name, detections, image_metadata)
             self.callback("on_image_loop_end",
-                          image_metadata=metadata, image=image,
+                          image_metadata=image_metadata.iloc[0], image=image,
                           image_idx=frame_idx, detections=detections)

+        # Finalize module callbacks at the end
+        for model_name in model_names:
+            self.callback("on_module_end", task=model_name, detections=detections)
+
         return detections

-    def default_step(self, batch: Any, task: str, detections: pd.DataFrame, metadata, **kwargs):
+    def default_step(self, batch: Any, task: str, detections: pd.DataFrame, image_metadata: pd.DataFrame, **kwargs):
         model = self.models[task]
         self.callback(f"on_module_step_start", task=task, batch=batch)
         idxs, batch = batch
         idxs = idxs.cpu() if isinstance(idxs, torch.Tensor) else idxs
         if model.level == "image":
             log.info(f"step : {idxs}")
-            batch_metadatas = pd.DataFrame([metadata])
             if len(detections) > 0:
                 batch_input_detections = detections.loc[
-                    np.isin(detections.image_id, batch_metadatas.index)
+                    np.isin(detections.image_id, image_metadata.index)
                 ]
             else:
                 batch_input_detections = detections
             batch_detections = self.models[task].process(
                 batch,
                 batch_input_detections,
-                batch_metadatas)
+                image_metadata)
         else:
             batch_detections = detections.loc[idxs]
             batch_detections = self.models[task].process(
                 batch=batch,
                 detections=batch_detections,
-                metadatas=None,
+                metadatas=image_metadata,
                 **kwargs,
             )
+
+        # Handle tuple return values (some modules return (detections, metadatas))
+        if isinstance(batch_detections, tuple):
+            batch_detections, batch_metadatas = batch_detections
+            # Update image metadata with outputs from this module
+            image_metadata = merge_dataframes(image_metadata, batch_metadatas)
+
         detections = merge_dataframes(detections, batch_detections)
         self.callback(
             f"on_module_step_end", task=task, batch=batch, detections=detections
         )
-        return detections
+        return detections, image_metadata
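For context on the contract this file now expects: `default_step` receives the per-frame `image_metadata` DataFrame, forwards it to the module's `process`, and unpacks an optional `(detections, metadatas)` tuple so extra per-frame outputs are merged back through `merge_dataframes` before the next module runs. Below is a minimal sketch of an image-level module that would fit the updated loop; only the `level`, `preprocess`, `collate_fn`, and `process` hooks come from the diff, while the class name and the column values are illustrative assumptions rather than tracklab API.

import pandas as pd

class DummyDetector:
    # Illustrative image-level module; everything except the hook names is an assumption.
    level = "image"

    def preprocess(self, image, detections, metadata):
        # Package whatever the model needs for this frame.
        return {"image": image}

    @classmethod
    def collate_fn(cls, batch):
        # `batch` is a list of (index, preprocessed) pairs, as built in video_loop.
        idxs = [idx for idx, _ in batch]
        samples = [sample for _, sample in batch]
        return idxs, samples

    def process(self, batch, detections, metadatas):
        # One fake detection per frame listed in `metadatas`.
        new_detections = pd.DataFrame({
            "image_id": list(metadatas.index),
            "bbox_ltwh": [[0, 0, 10, 10]] * len(metadatas),
        })
        # Extra per-frame outputs: returning a (detections, metadatas) tuple is
        # what the updated default_step unpacks and merges into image_metadata.
        extra_metadata = pd.DataFrame(
            {"mean_brightness": [127.0] * len(metadatas)}, index=metadatas.index
        )
        return new_detections, extra_metadata

A module that keeps returning a single DataFrame is still handled unchanged, since the tuple branch in `default_step` is guarded by `isinstance(batch_detections, tuple)`.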

tracklab/visualization/visualization_engine.py

Lines changed: 94 additions & 25 deletions
@@ -6,6 +6,8 @@

 import cv2
 import pandas as pd
+import numpy as np
+import platform

 from tracklab.callbacks import Progressbar, Callback
 from tracklab.visualization import Visualizer
@@ -23,6 +25,7 @@ class VisualizationEngine(Callback):
            `draw_detection`.
        save_images: whether to save the visualization as image files (.jpeg)
        save_videos: whether to save the visualization as video files (.mp4)
+        show_online: whether to show online tracking in realtime (only work if the pipeline doesn't involve VideoLevelModule)
        process_n_videos: number of videos to visualize. Will visualize the first N videos.
        process_n_frames_by_video: number of frames per video to visualize. Will visualize
            frames every N/n frames (not first n frames)
@@ -32,6 +35,7 @@ def __init__(self,
                 visualizers: Dict[str, Visualizer],
                 save_images: bool = False,
                 save_videos: bool = False,
+                 show_online: bool = False,
                 video_fps: int = 25,
                 process_n_videos: Optional[int] = None,
                 process_n_frames_by_video: Optional[int] = None,
@@ -41,13 +45,17 @@
         self.save_dir = Path("visualization")
         self.save_images = save_images
         self.save_videos = save_videos
+        self.show_online = show_online
         self.video_fps = video_fps
         self.max_videos = process_n_videos
         self.max_frames = process_n_frames_by_video
+        self.windows = []
         for visualizer in visualizers.values():
             visualizer.post_init(**kwargs)

     def on_dataset_track_end(self, engine: "TrackingEngine"):
+        if self.show_online:
+            cv2.destroyAllWindows()
         if self.save_videos or self.save_images:
             log.info(f"Visualization output at : {self.save_dir.absolute()}")

@@ -58,31 +66,92 @@ def on_video_loop_end(self, engine, video_metadata, video_idx, detections,
         self.visualize(engine.tracker_state, video_idx, detections, image_pred, progress)
         progress.on_module_end(None, "vis", None)

-    """
-    #TODO implement the online visualization
-    previous code:
-    if self.cfg.show_online:
-        tracker_state = engine.tracker_state
-        if tracker_state.detections_gt is not None:
-            ground_truths = tracker_state.detections_gt[
-                tracker_state.detections_gt.image_id == image_metadata.name
-            ]
-        else:
-            ground_truths = None
-        if len(detections) == 0:
-            image = image
-        else:
-            detections = detections[detections.image_id == image_metadata.name]
-            image = self.draw_frame(image_metadata,
-                                    detections, ground_truths, "inf", image=image)
-        if platform.system() == "Linux" and self.video_name not in self.windows:
-            self.windows.append(self.video_name)
-            cv2.namedWindow(str(self.video_name),
-                            cv2.WINDOW_NORMAL | cv2.WINDOW_KEEPRATIO)  # allow window resize (Linux)
-            cv2.resizeWindow(str(self.video_name), image.shape[1], image.shape[0])
-        cv2.imshow(str(self.video_name), image)
-        cv2.waitKey(1)
-    """
+    def on_image_loop_end(self, engine, image_metadata, image, image_idx, detections):
+        """
+        Handle real-time display during online video tracking.
+        """
+        if not self.show_online:
+            return
+
+        try:
+            # Filter detections for current frame
+            frame_detections = (
+                detections[detections.image_id == image_metadata.name]
+                if len(detections) > 0
+                else pd.DataFrame()
+            )
+
+            # Get ground truth (usually None for online tracking)
+            ground_truths = pd.DataFrame()
+
+            # Create dummy image metadata for compatibility
+            image_pred = pd.Series(
+                {
+                    "lines": getattr(image_metadata, "lines", {}),
+                    "keypoints": getattr(image_metadata, "keypoints", {}),
+                    "file_path": f"frame_{image_idx:06d}.jpg",  # Dummy path
+                },
+                name=image_metadata.name,
+            )
+
+            image_gt = pd.Series(
+                {
+                    "frame": image_idx,
+                    "nframes": -1,  # Unknown total frames in online mode
+                },
+                name=image_metadata.name,
+            )
+
+            # Draw frame with all visualizers
+            display_image = self.draw_online_frame(
+                image_metadata,
+                image,
+                frame_detections,
+                ground_truths,
+                image_pred,
+                image_gt,
+                nframes=-1,
+            )
+
+            # Display the image
+            video_name = str(engine.video_filename)
+            if platform.system() == "Linux" and video_name not in self.windows:
+                self.windows.append(video_name)
+                cv2.namedWindow(video_name, cv2.WINDOW_NORMAL | cv2.WINDOW_KEEPRATIO)
+                cv2.resizeWindow(
+                    video_name, display_image.shape[1], display_image.shape[0]
+                )
+
+            # Convert RGB to BGR for OpenCV display
+            cv2.imshow(video_name, display_image)
+            cv2.waitKey(1)  # Non-blocking wait
+
+        except Exception as e:
+            log.warning(f"Error in online visualization: {e}")
+
+    def draw_online_frame(
+        self,
+        image_metadata,
+        image,
+        detections_pred,
+        detections_gt,
+        image_pred,
+        image_gt,
+        nframes,
+    ):
+        """Draw frame using all configured visualizers."""
+        # Create a copy of the image to avoid modifying the original
+        image = np.copy(image)
+
+        for visualizer in self.visualizers.values():
+            try:
+                visualizer.draw_frame(
+                    image, detections_pred, detections_gt, image_pred, image_gt
+                )
+            except Exception as e:
+                log.warning(f"Visualizer {type(visualizer).__name__} raised error: {e}")
+
+        return final_patch(image)

     def visualize(self, tracker_state: TrackerState, video_id, detections, image_preds, progress=None):
         image_metadatas = tracker_state.image_metadatas[tracker_state.image_metadatas.video_id == video_id]
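With the flag wired through, online display is opted into when the visualization callback is constructed. Below is a minimal sketch using only the constructor signature shown in this diff; the empty `visualizers` dict and the remarks about pipeline wiring are placeholders, since the commit does not show how callbacks are assembled into a TrackingEngine.

from tracklab.visualization.visualization_engine import VisualizationEngine

vis_engine = VisualizationEngine(
    visualizers={},    # normally a Dict[str, Visualizer] built from the pipeline config
    save_videos=True,  # still writes the .mp4 at the end of the video loop
    show_online=True,  # new: open a cv2 window and refresh it on every frame
    video_fps=25,
)

# During online tracking, tracklab/engine/video.py fires "on_image_loop_end" once per
# frame; with show_online=True this callback filters the detections for that frame,
# draws them with every configured visualizer, and displays the result via
# cv2.imshow / cv2.waitKey(1). cv2.destroyAllWindows() runs at on_dataset_track_end.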
