Skip to content

Commit 0d8114e

Browse files
refactor (#2202)
1 parent 538dc16 commit 0d8114e

File tree

4 files changed

+105
-74
lines changed

4 files changed

+105
-74
lines changed

README.md

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -93,20 +93,20 @@ Where:
9393

9494
**Quick Examples:**
9595
```bash
96-
# Track with webcam
97-
boxmot track yolov8n osnet_x0_25_msmt17 deepocsort --source 0 --show
96+
# Track with webcam, save results, show basic results
97+
boxmot track yolov8n osnet_x0_25_msmt17 deepocsort --source 0 --show --save
9898

99-
# Track a video file
100-
boxmot track yolov8n osnet_x0_25_msmt17 botsort --source video.mp4 --save
99+
# Track a video file, save results, show trajectories + lost tracks
100+
boxmot track yolov8n osnet_x0_25_msmt17 botsort --source video.mp4 --save --show-trajectories --show-lost
101101

102102
# Evaluate on MOT dataset
103-
boxmot eval yolox_x_MOT17_ablation lmbn_n_duke botsort --source MOT17-ablation --classes 0,2 --source MOT17-ablation
103+
boxmot eval yolox_x_MOT17_ablation lmbn_n_duke botsort --source MOT17-ablation
104104

105-
# Tune tracker hyperparameters
106-
boxmot eval yolox_x_dancetrack_ablation lmbn_n_duke botsort --source MOT17-ablation --classes 0,2 --source dancetrack-ablation --n-trials 1000
105+
# Tune ocsort's hyperparameters for dancetrack
106+
boxmot tune yolox_x_dancetrack_ablation lmbn_n_duke ocsort --source dancetrack-ablation --n-trials 10
107107

108-
# Export ReID model
109-
boxmot export --weights osnet_x0_25_msmt17.pt --include onnx --include openvino --dynamic
108+
# Export ReID model with dynamic sized input
109+
boxmot export --weights osnet_x0_25_msmt17.pt --include onnx --include engine dynamic
110110
```
111111

112112
## 🐍 PYTHON

boxmot/engine/cli.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,7 @@ def core_options(func):
9191
click.option('--show-trajectories', is_flag=True,
9292
help='overlay past trajectories'),
9393
click.option('--show-lost', is_flag=True,
94-
help='show lost and removed tracks'),
94+
help='show lost tracks'),
9595
click.option('--save-txt', is_flag=True,
9696
help='save results to a .txt file'),
9797
click.option('--save-crop', is_flag=True,

boxmot/trackers/basetracker.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from boxmot.utils.visualization import VisualizationMixin
1111

1212

13-
class BaseTracker(ABC, VisualizationMixin):
13+
class BaseTracker(VisualizationMixin):
1414
def __init__(
1515
self,
1616
det_thresh: float = 0.3,

boxmot/utils/visualization.py

Lines changed: 94 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,14 @@
11
import colorsys
22
import hashlib
3+
from abc import ABC, abstractmethod
4+
35
import cv2 as cv
46
import numpy as np
57

6-
class VisualizationMixin:
8+
9+
class BaseVisualization(ABC):
710
"""
8-
Mixin class for visualization methods in BaseTracker.
11+
Abstract base class for visualization methods in BaseTracker.
912
"""
1013

1114
def id_to_color(
@@ -176,67 +179,9 @@ def _infer_state(self, a):
176179

177180
return "confirmed"
178181

182+
@abstractmethod
179183
def _display_groups(self):
180-
"""
181-
Yield groups of (tracks, forced_state, style) ready for drawing.
182-
If ByteTrack-style lists exist, use them with styles and TTLs.
183-
Otherwise, fall back to all active tracks and per-track inferred state.
184-
"""
185-
lost_list = getattr(self, "lost_stracks", None)
186-
removed_list = getattr(self, "removed_stracks", None)
187-
188-
# Maintain internal frame index for TTL accounting
189-
self._plot_frame_idx += 1
190-
now = self._plot_frame_idx
191-
192-
ttl = int(max(0, getattr(self, "removed_display_frames", self.removed_display_frames)))
193-
194-
if (lost_list is not None) or (removed_list is not None):
195-
# Active
196-
yield (self._all_active_tracks(), "confirmed", "solid")
197-
198-
# Lost (dashed, orange)
199-
if lost_list:
200-
yield (list(lost_list), "predicted", "dashed")
201-
202-
# Removed (gray, solid), with TTL + tombstone
203-
if removed_list and ttl > 0:
204-
filtered_removed = []
205-
for a in removed_list:
206-
if not getattr(a, "history_observations", None):
207-
continue
208-
sf = int(getattr(a, "start_frame", getattr(a, "birth_frame", -1)))
209-
rid = int(getattr(a, "id"))
210-
key = (rid, sf) if sf >= 0 else rid
211-
212-
if key in self._removed_expired:
213-
continue
214-
215-
if key not in self._removed_first_seen:
216-
self._removed_first_seen[key] = now
217-
218-
if (now - self._removed_first_seen[key]) < ttl:
219-
filtered_removed.append(a)
220-
else:
221-
self._removed_expired.add(key)
222-
223-
if filtered_removed:
224-
yield (filtered_removed, "removed", "solid")
225-
226-
# Optional: simple memory cap
227-
if len(self._removed_expired) > 10000:
228-
horizon = getattr(self, "removed_tombstone_horizon", 10000)
229-
cutoff = now - max(ttl, 1) - horizon
230-
to_drop = [k for k, t0 in self._removed_first_seen.items() if t0 < cutoff]
231-
for k in to_drop:
232-
self._removed_first_seen.pop(k, None)
233-
self._removed_expired.discard(k)
234-
235-
else:
236-
# Generic fallback: only active tracks; state per track
237-
active_tracks = self._all_active_tracks()
238-
if active_tracks:
239-
yield (active_tracks, None, "dashed")
184+
pass
240185

241186
def _draw_track(self, img, a, forced_state, style, thickness, fontscale, show_trajectories):
242187
if not getattr(a, "history_observations", None):
@@ -286,7 +231,7 @@ def plot_results(
286231
show_trajectories: bool,
287232
thickness: int = 2,
288233
fontscale: float = 0.5,
289-
show_lost: bool = True,
234+
show_lost: bool = False,
290235
) -> np.ndarray:
291236
"""
292237
Visualizes the trajectories of all active tracks on the image.
@@ -311,3 +256,89 @@ def plot_results(
311256
show_trajectories=show_trajectories,
312257
)
313258
return img
259+
260+
261+
class ExplicitStateVisualization(BaseVisualization):
262+
"""
263+
Visualization for trackers that maintain explicit lists for lost and removed tracks.
264+
"""
265+
266+
def _display_groups(self):
267+
lost_list = getattr(self, "lost_stracks", None)
268+
removed_list = getattr(self, "removed_stracks", None)
269+
270+
# Maintain internal frame index for TTL accounting
271+
self._plot_frame_idx += 1
272+
now = self._plot_frame_idx
273+
274+
ttl = int(max(0, getattr(self, "removed_display_frames", self.removed_display_frames)))
275+
276+
# Active
277+
yield (self._all_active_tracks(), "confirmed", "solid")
278+
279+
# Lost (dashed, orange)
280+
if lost_list:
281+
yield (list(lost_list), "predicted", "dashed")
282+
283+
# Removed (gray, solid), with TTL + tombstone
284+
if removed_list and ttl > 0:
285+
filtered_removed = []
286+
for a in removed_list:
287+
if not getattr(a, "history_observations", None):
288+
continue
289+
sf = int(getattr(a, "start_frame", getattr(a, "birth_frame", -1)))
290+
rid = int(getattr(a, "id"))
291+
key = (rid, sf) if sf >= 0 else rid
292+
293+
if key in self._removed_expired:
294+
continue
295+
296+
if key not in self._removed_first_seen:
297+
self._removed_first_seen[key] = now
298+
299+
if (now - self._removed_first_seen[key]) < ttl:
300+
filtered_removed.append(a)
301+
else:
302+
self._removed_expired.add(key)
303+
304+
if filtered_removed:
305+
yield (filtered_removed, "removed", "solid")
306+
307+
# Optional: simple memory cap
308+
if len(self._removed_expired) > 10000:
309+
horizon = getattr(self, "removed_tombstone_horizon", 10000)
310+
cutoff = now - max(ttl, 1) - horizon
311+
to_drop = [k for k, t0 in self._removed_first_seen.items() if t0 < cutoff]
312+
for k in to_drop:
313+
self._removed_first_seen.pop(k, None)
314+
self._removed_expired.discard(k)
315+
316+
317+
class InferredStateVisualization(BaseVisualization):
318+
"""
319+
Visualization for trackers that only expose active tracks and state is inferred.
320+
"""
321+
322+
def _display_groups(self):
323+
# Maintain internal frame index for TTL accounting
324+
self._plot_frame_idx += 1
325+
326+
# Generic fallback: only active tracks; state per track
327+
active_tracks = self._all_active_tracks()
328+
if active_tracks:
329+
yield (active_tracks, None, "dashed")
330+
331+
332+
class VisualizationMixin(BaseVisualization):
333+
"""
334+
Mixin class for visualization methods in BaseTracker.
335+
"""
336+
337+
def _display_groups(self):
338+
lost_list = getattr(self, "lost_stracks", None)
339+
removed_list = getattr(self, "removed_stracks", None)
340+
341+
if (lost_list is not None) or (removed_list is not None):
342+
return ExplicitStateVisualization._display_groups(self)
343+
else:
344+
return InferredStateVisualization._display_groups(self)

0 commit comments

Comments
 (0)