Skip to content

Commit c877349

Browse files
single process batched evaluation for detector and reid (#2220)
* enable tuning for visdrone dataset * save latest * imgsz for evaluation set in cli.py * resume from frame id in dets txts * cleanup * clarification comment * allow out of image detections * enable timing when in evaluation mode * fix multi-class eval * fix bug when multi-sequence
1 parent c9fd7a1 commit c877349

File tree

9 files changed

+1358
-357
lines changed

9 files changed

+1358
-357
lines changed

boxmot/detectors/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ def default_imgsz(yolo_name):
3131
if is_ultralytics_model(yolo_name):
3232
return [640, 640]
3333
elif is_yolox_model(yolo_name):
34-
return [800, 1440]
34+
return [1080, 1920]
3535
else:
3636
return [640, 640]
3737

boxmot/detectors/yolox.py

Lines changed: 68 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -24,9 +24,65 @@
2424
"yolox_x_MOT17_ablation.pt": "https://drive.google.com/uc?id=1iqhM-6V_r1FpOlOzrdP_Ejshgk0DxOob",
2525
"yolox_x_MOT20_ablation.pt": "https://drive.google.com/uc?id=1H1BxOfinONCSdQKnjGq0XlRxVUo_4M8o",
2626
"yolox_x_dancetrack_ablation.pt": "https://drive.google.com/uc?id=1ZKpYmFYCsRdXuOL60NRuc7VXAFYRskXB",
27+
"yolox_x_visdrone.pt": "https://drive.google.com/uc?id=1ajehBs9enBHhuBqGIoQPGqkkzasE9d3o"
2728
}
2829

2930

31+
def _coerce_torch_dtype(dtype, fallback: torch.Tensor) -> torch.dtype:
    """Map YOLOX's dtype strings (e.g., 'torch.mps.FloatTensor') to real torch dtypes.

    Args:
        dtype: Either a ``torch.dtype`` (returned unchanged), a string naming
            a tensor type or dtype (matched case-insensitively by substring),
            or any other value (treated as unresolved).
        fallback: Tensor whose dtype is used when ``dtype`` cannot be resolved.

    Returns:
        The resolved ``torch.dtype``. When ``dtype`` is unresolved the
        fallback tensor's dtype is returned, or ``torch.float32`` if
        ``fallback`` is not a tensor.
    """
    if isinstance(dtype, torch.dtype):
        return dtype
    if isinstance(dtype, str):
        lowered = dtype.lower()
        # Order matters: "bfloat16" contains "float16", so it is tested first.
        known = (
            (("bfloat16",), torch.bfloat16),
            (("float16", "half"), torch.float16),
            (("float64", "double"), torch.float64),
            (("float32", "floattensor"), torch.float32),
        )
        for markers, resolved in known:
            if any(marker in lowered for marker in markers):
                return resolved
    # Default to the fallback tensor's dtype or float32.
    return fallback.dtype if isinstance(fallback, torch.Tensor) else torch.float32
43+
44+
45+
def _patch_yolox_head_decode_outputs_for_mps() -> None:
    """Monkeypatch YOLOXHead.decode_outputs to work on MPS (avoids .type with dtype strings).

    The patch is best-effort: if the ``yolox`` imports fail for any reason
    the function returns without doing anything. A ``_boxmot_mps_patched``
    flag on ``YOLOXHead`` ensures the replacement is installed only once.
    """
    try:
        from yolox.models.yolo_head import YOLOXHead
        from yolox.utils import meshgrid
    except Exception:
        # YOLOX is unavailable or failed to import; nothing to patch.
        return

    already_patched = getattr(YOLOXHead, "_boxmot_mps_patched", False)
    if already_patched:
        return

    def decode_outputs(self, outputs, dtype):
        """Decode raw head outputs using per-level grids and strides.

        ``dtype`` may be a ``torch.dtype`` or a dtype string; it is resolved
        through ``_coerce_torch_dtype`` with ``outputs`` as the fallback.
        Returns a modified clone; ``outputs`` itself is left untouched.
        """
        target_dtype = _coerce_torch_dtype(dtype, outputs)
        dev = outputs.device

        grid_parts = []
        stride_parts = []
        for (rows, cols), step in zip(self.hw, self.strides):
            row_idx = torch.arange(rows, device=dev)
            col_idx = torch.arange(cols, device=dev)
            ys, xs = meshgrid([row_idx, col_idx])
            # Flatten the (x, y) cell coordinates of this level to (1, N, 2).
            cell_grid = torch.stack((xs, ys), 2).view(1, -1, 2)
            grid_parts.append(cell_grid)
            stride_shape = (*cell_grid.shape[:2], 1)
            stride_parts.append(
                torch.full(stride_shape, step, device=dev, dtype=cell_grid.dtype)
            )

        all_grids = torch.cat(grid_parts, dim=1).to(device=dev, dtype=target_dtype)
        all_strides = torch.cat(stride_parts, dim=1).to(device=dev, dtype=target_dtype)

        decoded = outputs.clone()
        # Channels 0:2 are shifted by the grid position and scaled by the stride.
        decoded[..., :2] = (decoded[..., :2] + all_grids) * all_strides
        # Channels 2:4 are exponentiated and scaled by the stride.
        decoded[..., 2:4] = torch.exp(decoded[..., 2:4]) * all_strides
        return decoded

    YOLOXHead.decode_outputs = decode_outputs
    YOLOXHead._boxmot_mps_patched = True
81+
82+
83+
# Install the decode_outputs patch once at import time; this is a no-op when
# the yolox imports fail or the patch has already been applied.
_patch_yolox_head_decode_outputs_for_mps()
84+
85+
3086
class YoloXStrategy:
3187
"""YOLOX strategy for use with Ultralytics predictor workflow."""
3288

@@ -135,9 +191,13 @@ def __init__(self, model, device, args):
135191
# Custom trained models (e.g., yolox_x_MOT17_ablation) use the base architecture
136192
if model_type == "yolox_n":
137193
exp_name = "yolox_nano"
138-
elif "_MOT" in model_type or "_dancetrack" in model_type:
139-
# Extract base model: yolox_x_MOT17_ablation -> yolox_x
140-
exp_name = model_type.split("_MOT")[0].split("_dancetrack")[0]
194+
elif "_MOT" in model_type or "_dancetrack" in model_type or "_visdrone" in model_type:
195+
# Extract base model: yolox_x_MOT17_ablation / yolox_x_visdrone -> yolox_x
196+
exp_name = (
197+
model_type.split("_MOT")[0]
198+
.split("_dancetrack")[0]
199+
.split("_visdrone")[0]
200+
)
141201
else:
142202
exp_name = model_type
143203
exp = get_exp(None, exp_name)
@@ -164,10 +224,13 @@ def __init__(self, model, device, args):
164224
self.device = device
165225
self.model = exp.get_model()
166226
self.model.eval()
167-
self.model.load_state_dict(ckpt["model"])
168-
self.model = fuse_model(self.model)
227+
228+
# follow official yolox loading procedure
229+
# https://github.com/Megvii-BaseDetection/YOLOX/blob/d872c71b/tools/eval.py#L148-L176
169230
self.model.to(self.device)
170231
self.model.eval()
232+
self.model.load_state_dict(ckpt["model"])
233+
self.model = fuse_model(self.model)
171234
self.im_paths = []
172235
self._preproc_data = []
173236

boxmot/engine/cli.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,14 @@ def core_options(func):
6161
help='IoU threshold for NMS'),
6262
click.option('--device', default='',
6363
help='cuda device(s), e.g. 0 or 0,1,2,3 or cpu'),
64+
click.option('--batch-size', type=int, default=16, show_default=True,
65+
help='micro-batch size for batched detection/embedding'),
66+
click.option('--auto-batch/--no-auto-batch', default=True, show_default=True,
67+
help='probe GPU memory with a dummy pass to pick a safe batch size'),
68+
click.option('--resume/--no-resume', default=True, show_default=True,
69+
help='resume detection/embedding generation from progress checkpoints'),
70+
click.option('--read-threads', type=int, default=None,
71+
help='CPU threads for image decoding; defaults to min(8, cpu_count)'),
6472
click.option('--project', type=Path, default=ROOT / 'runs',
6573
help='save results to project/name'),
6674
click.option('--name', default='', help='save results to project/name'),
@@ -423,7 +431,8 @@ def eval(ctx, detector, reid, tracker, yolo_model, reid_model, classes, **kwargs
423431
'classes': parse_classes(classes),
424432
'source': src,
425433
'benchmark': bench,
426-
'split': split}
434+
'split': split,
435+
'imgsz': [1088, 1920]}
427436
args = SimpleNamespace(**params)
428437
from boxmot.engine.evaluator import main as run_eval
429438
run_eval(args)

0 commit comments

Comments
 (0)