|
24 | 24 | "yolox_x_MOT17_ablation.pt": "https://drive.google.com/uc?id=1iqhM-6V_r1FpOlOzrdP_Ejshgk0DxOob", |
25 | 25 | "yolox_x_MOT20_ablation.pt": "https://drive.google.com/uc?id=1H1BxOfinONCSdQKnjGq0XlRxVUo_4M8o", |
26 | 26 | "yolox_x_dancetrack_ablation.pt": "https://drive.google.com/uc?id=1ZKpYmFYCsRdXuOL60NRuc7VXAFYRskXB", |
| 27 | + "yolox_x_visdrone.pt": "https://drive.google.com/uc?id=1ajehBs9enBHhuBqGIoQPGqkkzasE9d3o" |
27 | 28 | } |
28 | 29 |
|
29 | 30 |
|
| 31 | +def _coerce_torch_dtype(dtype, fallback: torch.Tensor) -> torch.dtype: |
| 32 | + """Map YOLOX's dtype strings (e.g., 'torch.mps.FloatTensor') to real torch dtypes.""" |
| 33 | + if isinstance(dtype, torch.dtype): |
| 34 | + return dtype |
| 35 | + if isinstance(dtype, str): |
| 36 | + lowered = dtype.lower() |
| 37 | + if "bfloat16" in lowered: |
| 38 | + return torch.bfloat16 |
| 39 | + if "float16" in lowered or "half" in lowered: |
| 40 | + return torch.float16 |
| 41 | + # Default to the fallback tensor's dtype or float32. |
| 42 | + return fallback.dtype if isinstance(fallback, torch.Tensor) else torch.float32 |
| 43 | + |
| 44 | + |
| 45 | +def _patch_yolox_head_decode_outputs_for_mps() -> None: |
| 46 | + """Monkeypatch YOLOXHead.decode_outputs to work on MPS (avoids .type with dtype strings).""" |
| 47 | + try: |
| 48 | + from yolox.models.yolo_head import YOLOXHead |
| 49 | + from yolox.utils import meshgrid |
| 50 | + except Exception: |
| 51 | + return |
| 52 | + |
| 53 | + if getattr(YOLOXHead, "_boxmot_mps_patched", False): |
| 54 | + return |
| 55 | + |
| 56 | + def decode_outputs(self, outputs, dtype): |
| 57 | + dtype = _coerce_torch_dtype(dtype, outputs) |
| 58 | + device = outputs.device |
| 59 | + grids = [] |
| 60 | + strides = [] |
| 61 | + for (hsize, wsize), stride in zip(self.hw, self.strides): |
| 62 | + yv, xv = meshgrid([ |
| 63 | + torch.arange(hsize, device=device), |
| 64 | + torch.arange(wsize, device=device), |
| 65 | + ]) |
| 66 | + grid = torch.stack((xv, yv), 2).view(1, -1, 2) |
| 67 | + grids.append(grid) |
| 68 | + shape = grid.shape[:2] |
| 69 | + strides.append(torch.full((*shape, 1), stride, device=device, dtype=grid.dtype)) |
| 70 | + |
| 71 | + grids = torch.cat(grids, dim=1).to(device=device, dtype=dtype) |
| 72 | + strides = torch.cat(strides, dim=1).to(device=device, dtype=dtype) |
| 73 | + |
| 74 | + outputs = outputs.clone() |
| 75 | + outputs[..., :2] = (outputs[..., :2] + grids) * strides |
| 76 | + outputs[..., 2:4] = torch.exp(outputs[..., 2:4]) * strides |
| 77 | + return outputs |
| 78 | + |
| 79 | + YOLOXHead.decode_outputs = decode_outputs |
| 80 | + YOLOXHead._boxmot_mps_patched = True |
| 81 | + |
| 82 | + |
| 83 | +_patch_yolox_head_decode_outputs_for_mps() |
| 84 | + |
| 85 | + |
30 | 86 | class YoloXStrategy: |
31 | 87 | """YOLOX strategy for use with Ultralytics predictor workflow.""" |
32 | 88 |
|
@@ -135,9 +191,13 @@ def __init__(self, model, device, args): |
135 | 191 | # Custom trained models (e.g., yolox_x_MOT17_ablation) use the base architecture |
136 | 192 | if model_type == "yolox_n": |
137 | 193 | exp_name = "yolox_nano" |
138 | | - elif "_MOT" in model_type or "_dancetrack" in model_type: |
139 | | - # Extract base model: yolox_x_MOT17_ablation -> yolox_x |
140 | | - exp_name = model_type.split("_MOT")[0].split("_dancetrack")[0] |
| 194 | + elif "_MOT" in model_type or "_dancetrack" in model_type or "_visdrone" in model_type: |
| 195 | + # Extract base model: yolox_x_MOT17_ablation / yolox_x_visdrone -> yolox_x |
| 196 | + exp_name = ( |
| 197 | + model_type.split("_MOT")[0] |
| 198 | + .split("_dancetrack")[0] |
| 199 | + .split("_visdrone")[0] |
| 200 | + ) |
141 | 201 | else: |
142 | 202 | exp_name = model_type |
143 | 203 | exp = get_exp(None, exp_name) |
|
0 commit comments