Skip to content

Commit 96c6d61

Browse files
save latest
1 parent ed1e33b commit 96c6d61

File tree

4 files changed

+513
-68
lines changed

4 files changed

+513
-68
lines changed

boxmot/detectors/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ def default_imgsz(yolo_name):
3131
if is_ultralytics_model(yolo_name):
3232
return [640, 640]
3333
elif is_yolox_model(yolo_name):
34-
return [800, 1440]
34+
return [1080, 1920]
3535
else:
3636
return [640, 640]
3737

boxmot/detectors/yolox.py

Lines changed: 63 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,9 +24,65 @@
2424
"yolox_x_MOT17_ablation.pt": "https://drive.google.com/uc?id=1iqhM-6V_r1FpOlOzrdP_Ejshgk0DxOob",
2525
"yolox_x_MOT20_ablation.pt": "https://drive.google.com/uc?id=1H1BxOfinONCSdQKnjGq0XlRxVUo_4M8o",
2626
"yolox_x_dancetrack_ablation.pt": "https://drive.google.com/uc?id=1ZKpYmFYCsRdXuOL60NRuc7VXAFYRskXB",
27+
"yolox_x_visdrone.pt": "https://drive.google.com/uc?id=1ajehBs9enBHhuBqGIoQPGqkkzasE9d3o"
2728
}
2829

2930

31+
def _coerce_torch_dtype(dtype, fallback: torch.Tensor) -> torch.dtype:
32+
"""Map YOLOX's dtype strings (e.g., 'torch.mps.FloatTensor') to real torch dtypes."""
33+
if isinstance(dtype, torch.dtype):
34+
return dtype
35+
if isinstance(dtype, str):
36+
lowered = dtype.lower()
37+
if "bfloat16" in lowered:
38+
return torch.bfloat16
39+
if "float16" in lowered or "half" in lowered:
40+
return torch.float16
41+
# Default to the fallback tensor's dtype or float32.
42+
return fallback.dtype if isinstance(fallback, torch.Tensor) else torch.float32
43+
44+
45+
def _patch_yolox_head_decode_outputs_for_mps() -> None:
46+
"""Monkeypatch YOLOXHead.decode_outputs to work on MPS (avoids .type with dtype strings)."""
47+
try:
48+
from yolox.models.yolo_head import YOLOXHead
49+
from yolox.utils import meshgrid
50+
except Exception:
51+
return
52+
53+
if getattr(YOLOXHead, "_boxmot_mps_patched", False):
54+
return
55+
56+
def decode_outputs(self, outputs, dtype):
57+
dtype = _coerce_torch_dtype(dtype, outputs)
58+
device = outputs.device
59+
grids = []
60+
strides = []
61+
for (hsize, wsize), stride in zip(self.hw, self.strides):
62+
yv, xv = meshgrid([
63+
torch.arange(hsize, device=device),
64+
torch.arange(wsize, device=device),
65+
])
66+
grid = torch.stack((xv, yv), 2).view(1, -1, 2)
67+
grids.append(grid)
68+
shape = grid.shape[:2]
69+
strides.append(torch.full((*shape, 1), stride, device=device, dtype=grid.dtype))
70+
71+
grids = torch.cat(grids, dim=1).to(device=device, dtype=dtype)
72+
strides = torch.cat(strides, dim=1).to(device=device, dtype=dtype)
73+
74+
outputs = outputs.clone()
75+
outputs[..., :2] = (outputs[..., :2] + grids) * strides
76+
outputs[..., 2:4] = torch.exp(outputs[..., 2:4]) * strides
77+
return outputs
78+
79+
YOLOXHead.decode_outputs = decode_outputs
80+
YOLOXHead._boxmot_mps_patched = True
81+
82+
83+
_patch_yolox_head_decode_outputs_for_mps()
84+
85+
3086
class YoloXStrategy:
3187
"""YOLOX strategy for use with Ultralytics predictor workflow."""
3288

@@ -135,9 +191,13 @@ def __init__(self, model, device, args):
135191
# Custom trained models (e.g., yolox_x_MOT17_ablation) use the base architecture
136192
if model_type == "yolox_n":
137193
exp_name = "yolox_nano"
138-
elif "_MOT" in model_type or "_dancetrack" in model_type:
139-
# Extract base model: yolox_x_MOT17_ablation -> yolox_x
140-
exp_name = model_type.split("_MOT")[0].split("_dancetrack")[0]
194+
elif "_MOT" in model_type or "_dancetrack" in model_type or "_visdrone" in model_type:
195+
# Extract base model: yolox_x_MOT17_ablation / yolox_x_visdrone -> yolox_x
196+
exp_name = (
197+
model_type.split("_MOT")[0]
198+
.split("_dancetrack")[0]
199+
.split("_visdrone")[0]
200+
)
141201
else:
142202
exp_name = model_type
143203
exp = get_exp(None, exp_name)

boxmot/engine/cli.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,14 @@ def core_options(func):
6161
help='IoU threshold for NMS'),
6262
click.option('--device', default='',
6363
help='cuda device(s), e.g. 0 or 0,1,2,3 or cpu'),
64+
click.option('--batch-size', type=int, default=16, show_default=True,
65+
help='micro-batch size for batched detection/embedding'),
66+
click.option('--auto-batch/--no-auto-batch', default=True, show_default=True,
67+
help='probe GPU memory with a dummy pass to pick a safe batch size'),
68+
click.option('--resume/--no-resume', default=True, show_default=True,
69+
help='resume detection/embedding generation from progress checkpoints'),
70+
click.option('--read-threads', type=int, default=None,
71+
help='CPU threads for image decoding; defaults to min(8, cpu_count)'),
6472
click.option('--project', type=Path, default=ROOT / 'runs',
6573
help='save results to project/name'),
6674
click.option('--name', default='', help='save results to project/name'),

0 commit comments

Comments (0)