
Commit 7a64060

Support Online Tracking
1 parent a005f4b commit 7a64060

File tree: 2 files changed (+132, -40 lines)


tracklab/engine/video.py

Lines changed: 38 additions & 15 deletions
@@ -51,14 +51,14 @@ def track_dataset(self):
         self.callback("on_dataset_track_start")
         self.callback(
             "on_video_loop_start",
-            video_metadata=pd.Series(name=self.video_filename),
+            video_metadata=pd.Series({"name": self.video_filename}),
             video_idx=0,
             index=0,
         )
         detections = self.video_loop()
         self.callback(
             "on_video_loop_end",
-            video_metadata=pd.Series(name=self.video_filename),
+            video_metadata=pd.Series({"name": self.video_filename}),
             video_idx=0,
             detections=detections,
         )
@@ -81,6 +81,11 @@ def video_loop(self):
         # print('in offline.py, model_names: ', model_names)
         frame_idx = -1
         detections = pd.DataFrame()
+
+        # Initialize module callbacks at the start
+        for model_name in model_names:
+            self.callback("on_module_start", task=model_name, dataloader=[])
+
         while video_cap.isOpened():
             frame_idx += 1
             ret, frame = video_cap.read()
@@ -89,10 +94,13 @@ def video_loop(self):
             image = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
             if not ret:
                 break
-            metadata = pd.Series({"id": frame_idx, "frame": frame_idx,
+            base_metadata = pd.Series({"id": frame_idx, "frame": frame_idx,
                                   "video_id": video_filename}, name=frame_idx)
             self.callback("on_image_loop_start",
-                          image_metadata=metadata, image_idx=frame_idx, index=frame_idx)
+                          image_metadata=base_metadata, image_idx=frame_idx, index=frame_idx)
+
+            image_metadata = pd.DataFrame([base_metadata])
+
             for model_name in model_names:
                 model = self.models[model_name]
                 if len(detections) > 0:
@@ -102,49 +110,64 @@ def video_loop(self):
                 if model.level == "video":
                     raise "Video-level not supported for online video tracking"
                 elif model.level == "image":
-                    batch = model.preprocess(image=image, detections=dets, metadata=metadata)
+                    batch = model.preprocess(image=image, detections=dets, metadata=image_metadata.iloc[0])
                     batch = type(model).collate_fn([(frame_idx, batch)])
-                    detections = self.default_step(batch, model_name, detections, metadata)
+                    detections, image_metadata = self.default_step(batch, model_name, detections, image_metadata)
                 elif model.level == "detection":
                     for idx, detection in dets.iterrows():
-                        batch = model.preprocess(image=image, detection=detection, metadata=metadata)
+                        batch = model.preprocess(image=image, detection=detection, metadata=image_metadata.iloc[0])
                         batch = type(model).collate_fn([(detection.name, batch)])
-                        detections = self.default_step(batch, model_name, detections, metadata)
+                        detections, image_metadata = self.default_step(batch, model_name, detections, image_metadata)
+
             self.callback("on_image_loop_end",
-                          image_metadata=metadata, image=image,
+                          image_metadata=image_metadata.iloc[0], image=image,
                           image_idx=frame_idx, detections=detections)

+        # Finalize module callbacks at the end
+        for model_name in model_names:
+            self.callback("on_module_end", task=model_name, detections=detections)
+
         return detections

-    def default_step(self, batch: Any, task: str, detections: pd.DataFrame, metadata, **kwargs):
+    def default_step(self, batch: Any, task: str, detections: pd.DataFrame, image_pred: pd.DataFrame, **kwargs):
         model = self.models[task]
         self.callback(f"on_module_step_start", task=task, batch=batch)
         idxs, batch = batch
         idxs = idxs.cpu() if isinstance(idxs, torch.Tensor) else idxs
         if model.level == "image":
-            log.info(f"step : {idxs}")
-            batch_metadatas = pd.DataFrame([metadata])
+            log.info(f"step : {idxs} --- task : {task}")
+            batch_metadatas = image_pred.loc[list(idxs)]  # self.img_metadatas.loc[idxs]
             if len(detections) > 0:
                 batch_input_detections = detections.loc[
                     np.isin(detections.image_id, batch_metadatas.index)
                 ]
             else:
                 batch_input_detections = detections
+
             batch_detections = self.models[task].process(
                 batch,
                 batch_input_detections,
                 batch_metadatas)
         else:
             batch_detections = detections.loc[idxs]
+            if not image_pred.empty:
+                batch_metadatas = image_pred.loc[np.isin(image_pred.index, batch_detections.image_id)]
+            else:
+                batch_metadatas = image_pred
             batch_detections = self.models[task].process(
                 batch=batch,
                 detections=batch_detections,
-                metadatas=None,
+                metadatas=batch_metadatas,
                 **kwargs,
             )
+
+        if isinstance(batch_detections, tuple):
+            batch_detections, batch_metadatas = batch_detections
+            image_pred = merge_dataframes(image_pred, batch_metadatas)
+
         detections = merge_dataframes(detections, batch_detections)
+
         self.callback(
             f"on_module_step_end", task=task, batch=batch, detections=detections
         )
-        return detections
-
+        return detections, image_pred

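Note on the first hunk: pd.Series(name=self.video_filename) builds an empty Series whose .name attribute is the filename, whereas pd.Series({"name": self.video_filename}) builds a Series with an actual "name" entry, which is presumably what downstream callbacks read.

For context, the online loop above now drives a fixed callback sequence per video: on_module_start for every module, then on_image_loop_start / on_image_loop_end for each frame, then on_module_end. The sketch below (not part of this commit) shows a minimal custom callback that consumes these hooks; the FrameLogger name and print statements are illustrative, and the method signatures follow the keyword arguments passed by self.callback(...) above plus the engine argument seen in visualization_engine.py, so they should be checked against tracklab.callbacks.Callback.

    from tracklab.callbacks import Callback


    class FrameLogger(Callback):
        """Illustrative callback that logs progress of the online video loop."""

        def on_module_start(self, engine, task, dataloader):
            # Fired once per pipeline module before the frame loop starts.
            print(f"module {task} ready")

        def on_image_loop_end(self, engine, image_metadata, image, image_idx, detections):
            # Fired after every frame; detections holds all results accumulated so far.
            print(f"frame {image_idx}: {len(detections)} detections")

        def on_module_end(self, engine, task, detections):
            # Fired once per module after the whole video has been processed.
            print(f"module {task} done with {len(detections)} detections")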
tracklab/visualization/visualization_engine.py

Lines changed: 94 additions & 25 deletions
@@ -6,6 +6,8 @@
 
 import cv2
 import pandas as pd
+import numpy as np
+import platform
 
 from tracklab.callbacks import Progressbar, Callback
 from tracklab.visualization import Visualizer
@@ -23,6 +25,7 @@ class VisualizationEngine(Callback):
            `draw_detection`.
        save_images: whether to save the visualization as image files (.jpeg)
        save_videos: whether to save the visualization as video files (.mp4)
+       show_online: whether to show online tracking in realtime (only work if the pipeline doesn't involve VideoLevelModule)
        process_n_videos: number of videos to visualize. Will visualize the first N videos.
        process_n_frames_by_video: number of frames per video to visualize. Will visualize
            frames every N/n frames (not first n frames)
@@ -32,6 +35,7 @@ def __init__(self,
                  visualizers: Dict[str, Visualizer],
                  save_images: bool = False,
                  save_videos: bool = False,
+                 show_online: bool = False,
                  video_fps: int = 25,
                  process_n_videos: Optional[int] = None,
                  process_n_frames_by_video: Optional[int] = None,
@@ -41,13 +45,17 @@ def __init__(self,
         self.save_dir = Path("visualization")
         self.save_images = save_images
         self.save_videos = save_videos
+        self.show_online = show_online
         self.video_fps = video_fps
         self.max_videos = process_n_videos
         self.max_frames = process_n_frames_by_video
+        self.windows = []
         for visualizer in visualizers.values():
             visualizer.post_init(**kwargs)
 
     def on_dataset_track_end(self, engine: "TrackingEngine"):
+        if self.show_online:
+            cv2.destroyAllWindows()
         if self.save_videos or self.save_images:
             log.info(f"Visualization output at : {self.save_dir.absolute()}")
 
@@ -58,31 +66,92 @@ def on_video_loop_end(self, engine, video_metadata, video_idx, detections,
         self.visualize(engine.tracker_state, video_idx, detections, image_pred, progress)
         progress.on_module_end(None, "vis", None)
 
-    """
-    #TODO implement the online visualization
-    previous code:
-    if self.cfg.show_online:
-        tracker_state = engine.tracker_state
-        if tracker_state.detections_gt is not None:
-            ground_truths = tracker_state.detections_gt[
-                tracker_state.detections_gt.image_id == image_metadata.name
-            ]
-        else:
-            ground_truths = None
-        if len(detections) == 0:
-            image = image
-        else:
-            detections = detections[detections.image_id == image_metadata.name]
-            image = self.draw_frame(image_metadata,
-                                    detections, ground_truths, "inf", image=image)
-        if platform.system() == "Linux" and self.video_name not in self.windows:
-            self.windows.append(self.video_name)
-            cv2.namedWindow(str(self.video_name),
-                            cv2.WINDOW_NORMAL | cv2.WINDOW_KEEPRATIO)  # allow window resize (Linux)
-            cv2.resizeWindow(str(self.video_name), image.shape[1], image.shape[0])
-        cv2.imshow(str(self.video_name), image)
-        cv2.waitKey(1)
-    """
+    def on_image_loop_end(self, engine, image_metadata, image, image_idx, detections):
+        """
+        Handle real-time display during online video tracking.
+        """
+        if not self.show_online:
+            return
+
+        try:
+            # Filter detections for current frame
+            frame_detections = (
+                detections[detections.image_id == image_metadata.name]
+                if len(detections) > 0
+                else pd.DataFrame()
+            )
+
+            # Get ground truth (usually None for online tracking)
+            ground_truths = pd.DataFrame()
+
+            # Create dummy image metadata for compatibility
+            image_pred = pd.Series(
+                {
+                    "lines": getattr(image_metadata, "lines", {}),
+                    "keypoints": getattr(image_metadata, "keypoints", {}),
+                    "file_path": f"frame_{image_idx:06d}.jpg",  # Dummy path
+                },
+                name=image_metadata.name,
+            )
+
+            image_gt = pd.Series(
+                {
+                    "frame": image_idx,
+                    "nframes": -1,  # Unknown total frames in online mode
+                },
+                name=image_metadata.name,
+            )
+
+            # Draw frame with all visualizers
+            display_image = self.draw_online_frame(
+                image_metadata,
+                image,
+                frame_detections,
+                ground_truths,
+                image_pred,
+                image_gt,
+                nframes=-1,
+            )
+
+            # Display the image
+            video_name = str(engine.video_filename)
+            if platform.system() == "Linux" and video_name not in self.windows:
+                self.windows.append(video_name)
+                cv2.namedWindow(video_name, cv2.WINDOW_NORMAL | cv2.WINDOW_KEEPRATIO)
+                cv2.resizeWindow(
+                    video_name, display_image.shape[1], display_image.shape[0]
+                )
+
+            # Convert RGB to BGR for OpenCV display
+            cv2.imshow(video_name, display_image)
+            cv2.waitKey(1)  # Non-blocking wait
+
+        except Exception as e:
+            log.warning(f"Error in online visualization: {e}")
+
+    def draw_online_frame(
+        self,
+        image_metadata,
+        image,
+        detections_pred,
+        detections_gt,
+        image_pred,
+        image_gt,
+        nframes,
+    ):
+        """Draw frame using all configured visualizers."""
+        # Create a copy of the image to avoid modifying the original
+        image = np.copy(image)
+
+        for visualizer in self.visualizers.values():
+            try:
+                visualizer.draw_frame(
+                    image, detections_pred, detections_gt, image_pred, image_gt
+                )
+            except Exception as e:
+                log.warning(f"Visualizer {type(visualizer).__name__} raised error: {e}")
+
+        return final_patch(image)
 
     def visualize(self, tracker_state: TrackerState, video_id, detections, image_preds, progress=None):
         image_metadatas = tracker_state.image_metadatas[tracker_state.image_metadatas.video_id == video_id]

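To make use of the real-time display, show_online has to be set when the VisualizationEngine is built (in a normal TrackLab run this goes through the project configuration rather than direct instantiation). The snippet below is only a sketch under those assumptions: the import path follows the file location shown above, "my_visualizers" stands for whatever Dict[str, Visualizer] the pipeline provides, and any constructor arguments not shown in this diff are assumed to keep their defaults.

    from tracklab.visualization.visualization_engine import VisualizationEngine

    my_visualizers = {}  # placeholder for Dict[str, Visualizer] instances

    vis_engine = VisualizationEngine(
        visualizers=my_visualizers,
        save_images=False,
        save_videos=False,
        show_online=True,  # flag added by this commit
        video_fps=25,
    )

As the new docstring notes, the online display only works when the pipeline contains no VideoLevelModule, which is consistent with video.py raising an error for video-level modules during online tracking.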