2424 "yolox_x_MOT17_ablation.pt" : "https://drive.google.com/uc?id=1iqhM-6V_r1FpOlOzrdP_Ejshgk0DxOob" ,
2525 "yolox_x_MOT20_ablation.pt" : "https://drive.google.com/uc?id=1H1BxOfinONCSdQKnjGq0XlRxVUo_4M8o" ,
2626 "yolox_x_dancetrack_ablation.pt" : "https://drive.google.com/uc?id=1ZKpYmFYCsRdXuOL60NRuc7VXAFYRskXB" ,
27+ "yolox_x_visdrone.pt" : "https://drive.google.com/uc?id=1ajehBs9enBHhuBqGIoQPGqkkzasE9d3o"
2728}
2829
2930
31+ def _coerce_torch_dtype (dtype , fallback : torch .Tensor ) -> torch .dtype :
32+ """Map YOLOX's dtype strings (e.g., 'torch.mps.FloatTensor') to real torch dtypes."""
33+ if isinstance (dtype , torch .dtype ):
34+ return dtype
35+ if isinstance (dtype , str ):
36+ lowered = dtype .lower ()
37+ if "bfloat16" in lowered :
38+ return torch .bfloat16
39+ if "float16" in lowered or "half" in lowered :
40+ return torch .float16
41+ # Default to the fallback tensor's dtype or float32.
42+ return fallback .dtype if isinstance (fallback , torch .Tensor ) else torch .float32
43+
44+
45+ def _patch_yolox_head_decode_outputs_for_mps () -> None :
46+ """Monkeypatch YOLOXHead.decode_outputs to work on MPS (avoids .type with dtype strings)."""
47+ try :
48+ from yolox .models .yolo_head import YOLOXHead
49+ from yolox .utils import meshgrid
50+ except Exception :
51+ return
52+
53+ if getattr (YOLOXHead , "_boxmot_mps_patched" , False ):
54+ return
55+
56+ def decode_outputs (self , outputs , dtype ):
57+ dtype = _coerce_torch_dtype (dtype , outputs )
58+ device = outputs .device
59+ grids = []
60+ strides = []
61+ for (hsize , wsize ), stride in zip (self .hw , self .strides ):
62+ yv , xv = meshgrid ([
63+ torch .arange (hsize , device = device ),
64+ torch .arange (wsize , device = device ),
65+ ])
66+ grid = torch .stack ((xv , yv ), 2 ).view (1 , - 1 , 2 )
67+ grids .append (grid )
68+ shape = grid .shape [:2 ]
69+ strides .append (torch .full ((* shape , 1 ), stride , device = device , dtype = grid .dtype ))
70+
71+ grids = torch .cat (grids , dim = 1 ).to (device = device , dtype = dtype )
72+ strides = torch .cat (strides , dim = 1 ).to (device = device , dtype = dtype )
73+
74+ outputs = outputs .clone ()
75+ outputs [..., :2 ] = (outputs [..., :2 ] + grids ) * strides
76+ outputs [..., 2 :4 ] = torch .exp (outputs [..., 2 :4 ]) * strides
77+ return outputs
78+
79+ YOLOXHead .decode_outputs = decode_outputs
80+ YOLOXHead ._boxmot_mps_patched = True
81+
82+
# Apply the patch once, when this module is imported.
_patch_yolox_head_decode_outputs_for_mps()
84+
85+
3086class YoloXStrategy :
3187 """YOLOX strategy for use with Ultralytics predictor workflow."""
3288
@@ -135,9 +191,13 @@ def __init__(self, model, device, args):
135191 # Custom trained models (e.g., yolox_x_MOT17_ablation) use the base architecture
136192 if model_type == "yolox_n" :
137193 exp_name = "yolox_nano"
138- elif "_MOT" in model_type or "_dancetrack" in model_type :
139- # Extract base model: yolox_x_MOT17_ablation -> yolox_x
140- exp_name = model_type .split ("_MOT" )[0 ].split ("_dancetrack" )[0 ]
194+ elif "_MOT" in model_type or "_dancetrack" in model_type or "_visdrone" in model_type :
195+ # Extract base model: yolox_x_MOT17_ablation / yolox_x_visdrone -> yolox_x
196+ exp_name = (
197+ model_type .split ("_MOT" )[0 ]
198+ .split ("_dancetrack" )[0 ]
199+ .split ("_visdrone" )[0 ]
200+ )
141201 else :
142202 exp_name = model_type
143203 exp = get_exp (None , exp_name )
@@ -164,10 +224,13 @@ def __init__(self, model, device, args):
164224 self .device = device
165225 self .model = exp .get_model ()
166226 self .model .eval ()
167- self .model .load_state_dict (ckpt ["model" ])
168- self .model = fuse_model (self .model )
227+
228+ # follow official yolox loading procedure
229+ # https://github.com/Megvii-BaseDetection/YOLOX/blob/d872c71b/tools/eval.py#L148-L176
169230 self .model .to (self .device )
170231 self .model .eval ()
232+ self .model .load_state_dict (ckpt ["model" ])
233+ self .model = fuse_model (self .model )
171234 self .im_paths = []
172235 self ._preproc_data = []
173236
0 commit comments