Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion boxmot/detectors/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ def default_imgsz(yolo_name):
if is_ultralytics_model(yolo_name):
return [640, 640]
elif is_yolox_model(yolo_name):
return [800, 1440]
return [1080, 1920]
else:
return [640, 640]

Expand Down
73 changes: 68 additions & 5 deletions boxmot/detectors/yolox.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,65 @@
"yolox_x_MOT17_ablation.pt": "https://drive.google.com/uc?id=1iqhM-6V_r1FpOlOzrdP_Ejshgk0DxOob",
"yolox_x_MOT20_ablation.pt": "https://drive.google.com/uc?id=1H1BxOfinONCSdQKnjGq0XlRxVUo_4M8o",
"yolox_x_dancetrack_ablation.pt": "https://drive.google.com/uc?id=1ZKpYmFYCsRdXuOL60NRuc7VXAFYRskXB",
"yolox_x_visdrone.pt": "https://drive.google.com/uc?id=1ajehBs9enBHhuBqGIoQPGqkkzasE9d3o"
}


def _coerce_torch_dtype(dtype, fallback: torch.Tensor) -> torch.dtype:
"""Map YOLOX's dtype strings (e.g., 'torch.mps.FloatTensor') to real torch dtypes."""
if isinstance(dtype, torch.dtype):
return dtype
if isinstance(dtype, str):
lowered = dtype.lower()
if "bfloat16" in lowered:
return torch.bfloat16
if "float16" in lowered or "half" in lowered:
return torch.float16
# Default to the fallback tensor's dtype or float32.
return fallback.dtype if isinstance(fallback, torch.Tensor) else torch.float32


def _patch_yolox_head_decode_outputs_for_mps() -> None:
"""Monkeypatch YOLOXHead.decode_outputs to work on MPS (avoids .type with dtype strings)."""
try:
from yolox.models.yolo_head import YOLOXHead
from yolox.utils import meshgrid
except Exception:
return

if getattr(YOLOXHead, "_boxmot_mps_patched", False):
return

def decode_outputs(self, outputs, dtype):
dtype = _coerce_torch_dtype(dtype, outputs)
device = outputs.device
grids = []
strides = []
for (hsize, wsize), stride in zip(self.hw, self.strides):
yv, xv = meshgrid([
torch.arange(hsize, device=device),
torch.arange(wsize, device=device),
])
grid = torch.stack((xv, yv), 2).view(1, -1, 2)
grids.append(grid)
shape = grid.shape[:2]
strides.append(torch.full((*shape, 1), stride, device=device, dtype=grid.dtype))

grids = torch.cat(grids, dim=1).to(device=device, dtype=dtype)
strides = torch.cat(strides, dim=1).to(device=device, dtype=dtype)

outputs = outputs.clone()
outputs[..., :2] = (outputs[..., :2] + grids) * strides
outputs[..., 2:4] = torch.exp(outputs[..., 2:4]) * strides
return outputs

YOLOXHead.decode_outputs = decode_outputs
YOLOXHead._boxmot_mps_patched = True


_patch_yolox_head_decode_outputs_for_mps()


class YoloXStrategy:
"""YOLOX strategy for use with Ultralytics predictor workflow."""

Expand Down Expand Up @@ -135,9 +191,13 @@ def __init__(self, model, device, args):
# Custom trained models (e.g., yolox_x_MOT17_ablation) use the base architecture
if model_type == "yolox_n":
exp_name = "yolox_nano"
elif "_MOT" in model_type or "_dancetrack" in model_type:
# Extract base model: yolox_x_MOT17_ablation -> yolox_x
exp_name = model_type.split("_MOT")[0].split("_dancetrack")[0]
elif "_MOT" in model_type or "_dancetrack" in model_type or "_visdrone" in model_type:
# Extract base model: yolox_x_MOT17_ablation / yolox_x_visdrone -> yolox_x
exp_name = (
model_type.split("_MOT")[0]
.split("_dancetrack")[0]
.split("_visdrone")[0]
)
else:
exp_name = model_type
exp = get_exp(None, exp_name)
Expand All @@ -164,10 +224,13 @@ def __init__(self, model, device, args):
self.device = device
self.model = exp.get_model()
self.model.eval()
self.model.load_state_dict(ckpt["model"])
self.model = fuse_model(self.model)

# follow the official YOLOX loading procedure
# https://github.com/Megvii-BaseDetection/YOLOX/blob/d872c71b/tools/eval.py#L148-L176
self.model.to(self.device)
self.model.eval()
self.model.load_state_dict(ckpt["model"])
self.model = fuse_model(self.model)
self.im_paths = []
self._preproc_data = []

Expand Down
11 changes: 10 additions & 1 deletion boxmot/engine/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,14 @@ def core_options(func):
help='IoU threshold for NMS'),
click.option('--device', default='',
help='cuda device(s), e.g. 0 or 0,1,2,3 or cpu'),
click.option('--batch-size', type=int, default=16, show_default=True,
help='micro-batch size for batched detection/embedding'),
click.option('--auto-batch/--no-auto-batch', default=True, show_default=True,
help='probe GPU memory with a dummy pass to pick a safe batch size'),
click.option('--resume/--no-resume', default=True, show_default=True,
help='resume detection/embedding generation from progress checkpoints'),
click.option('--read-threads', type=int, default=None,
help='CPU threads for image decoding; defaults to min(8, cpu_count)'),
click.option('--project', type=Path, default=ROOT / 'runs',
help='save results to project/name'),
click.option('--name', default='', help='save results to project/name'),
Expand Down Expand Up @@ -423,7 +431,8 @@ def eval(ctx, detector, reid, tracker, yolo_model, reid_model, classes, **kwargs
'classes': parse_classes(classes),
'source': src,
'benchmark': bench,
'split': split}
'split': split,
'imgsz': [1088, 1920]}
args = SimpleNamespace(**params)
from boxmot.engine.evaluator import main as run_eval
run_eval(args)
Expand Down
Loading
Loading