Skip to content

Commit 02f3dad

Browse files
Feat/enable multi class evaluation (#2199)
1 parent eaabda4 commit 02f3dad

File tree

16 files changed

+572
-108
lines changed

16 files changed

+572
-108
lines changed

.github/workflows/ci.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,7 @@ jobs:
7878
run: |
7979
source .venv/bin/activate
8080
# reuse first set of generated det and prod
81-
boxmot tune --yolo-model yolov8n.pt --reid-model osnet_x0_25_msmt17.pt --n-trials 3 --tracking-method strongsort --source ./assets/MOT17-mini/train --ci
81+
boxmot tune --yolo-model yolov8n.pt --reid-model osnet_x0_25_msmt17.pt --n-trials 3 --tracking-method strongsort --source ./assets/MOT17-mini/train --ci --classes 0
8282
mot-metrics-benchmark:
8383
runs-on: ${{ matrix.os }}
8484
strategy:
@@ -118,7 +118,7 @@ jobs:
118118
source .venv/bin/activate
119119
echo "Format,Status❔,HOTA,MOTA,IDF1" > results.csv
120120
for tracker in $TRACKERS; do
121-
if boxmot eval --yolo-model yolov8n.pt --reid-model osnet_x0_25_msmt17.pt --tracking-method $tracker --verbose --source ./assets/MOT17-mini/train --ci; then
121+
if boxmot eval --yolo-model yolov8n.pt --reid-model osnet_x0_25_msmt17.pt --tracking-method $tracker --verbose --source ./assets/MOT17-mini/train --ci --classes 0; then
122122
STATUS="✅"
123123
else
124124
STATUS="❌"

boxmot/configs/datasets/MOT17-ablation.yaml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,8 @@
22
download:
33
runs_url: "https://github.com/mikel-brostrom/boxmot/releases/download/v12.0.7/runs.zip"
44
dataset_url: "https://github.com/mikel-brostrom/boxmot/releases/download/v13.0.9/MOT17-ablation.zip"
5-
dataset_dest: "boxmot/engine/trackeval/MOT17-ablation.zip"
65

76
benchmark:
8-
name: "MOT17-ablation"
7+
source: "boxmot/engine/trackeval/MOT17-ablation"
98
split: "train"
9+
classes: "person"
Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
# https://motchallenge.net/data/MOT20/
22
download:
3-
runs_url: null
3+
runs_url: "https://github.com/mikel-brostrom/boxmot/releases/download/v12.0.7/runs.zip"
44
dataset_url: "https://github.com/mikel-brostrom/boxmot/releases/download/v13.0.9/MOT20-ablation.zip"
5-
dataset_dest: "boxmot/engine/trackeval/MOT20-ablation.zip"
65

76
benchmark:
8-
name: "MOT20-ablation"
7+
source: "boxmot/engine/trackeval/MOT20-ablation"
98
split: "train"
9+
classes: "person"
Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
1-
# https://motchallenge.net/data/MOT20/
1+
# https://github.com/SportsMOT/SportsMOT
22
download:
3-
runs_url: null
4-
dataset_url: "https://onedrive.live.com/?redeem=aHR0cHM6Ly8xZHJ2Lm1zL3UvcyFBdGplTHE3WW5ZR1JnUVJybXFHcjRCLWsteHNDP2U9N1BuZFU4&cid=91819DD8AE2EDED8&id=91819DD8AE2EDED8%21132&parId=91819DD8AE2EDED8%21129&o=OneUp"
5-
dataset_dest: "boxmot/engine/trackeval/SportsMOT.zip"
3+
runs_url: ""
4+
dataset_url: "https://github.com/mikel-brostrom/boxmot/releases/download/v13.0.9/SportsMOT.zip"
65

76
benchmark:
8-
name: "SportsMOT"
9-
split: "train"
7+
source: "boxmot/engine/trackeval/SportsMOT"
8+
split: "val"
9+
classes: "person"
Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
# https://motchallenge.net/data/MOT17/
2+
download:
3+
runs_url: ""
4+
dataset_url: ""
5+
6+
benchmark:
7+
source: "assets/MOT17-mini"
8+
split: "train"
9+
classes: "person"
Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
# https://huggingface.co/datasets/noahcao/dancetrack/tree/main
22
download:
33
runs_url: null
4-
dataset_url: "https://huggingface.co/datasets/noahcao/dancetrack/resolve/main/val.zip?download=true"
5-
dataset_dest: "boxmot/engine/trackeval/dancetrack-ablation.zip"
4+
dataset_url: "https://github.com/mikel-brostrom/boxmot/releases/download/v13.0.9/dancetrack-ablation.zip"
65

76
benchmark:
8-
name: "dancetrack-ablation"
7+
source: "boxmot/engine/trackeval/dancetrack-ablation"
98
split: "val"
9+
classes: "person"
Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
1-
# https://motchallenge.net/data/MOT20/
1+
# https://github.com/VisDrone/VisDrone-Dataset
22
download:
3-
runs_url: null
4-
dataset_url: "https://drive.google.com/uc?export=download&id=1rqnKe9IgU_crMaxRoel9_nuUsMEBBVQu"
5-
dataset_dest: "boxmot/engine/trackeval/vizdrone-ablation.zip"
3+
runs_url: ""
4+
dataset_url: "https://github.com/mikel-brostrom/boxmot/releases/download/v13.0.9/vizdrone-ablation.zip"
65

76
benchmark:
8-
name: "VisDrone2019-MOT-val"
9-
split: ""
7+
source: "boxmot/engine/trackeval/vizdrone-ablation"
8+
split: "train"
9+
classes: "person car truck bus van"

boxmot/engine/cli.py

Lines changed: 31 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,29 @@ def core_options(func):
9999
return func
100100

101101

102+
def parse_classes(classes_input):
103+
"""
104+
Parse classes input which can be a tuple of ints (from multiple=True),
105+
a string (comma/space separated), or None.
106+
Returns a list of integers or None.
107+
"""
108+
if classes_input is None:
109+
return None
110+
111+
if isinstance(classes_input, (list, tuple)):
112+
# If it's already a list/tuple of ints (from multiple=True)
113+
if not classes_input:
114+
return None
115+
return list(classes_input)
116+
117+
if isinstance(classes_input, str):
118+
# Handle string input: "0,1" or "0 1"
119+
classes_input = classes_input.replace(',', ' ')
120+
return [int(x) for x in classes_input.split()]
121+
122+
return [int(classes_input)]
123+
124+
102125
def singular_model_options(func):
103126
options = [
104127
click.option('--yolo-model', type=Path,
@@ -107,8 +130,8 @@ def singular_model_options(func):
107130
click.option('--reid-model', type=Path,
108131
default=WEIGHTS / 'osnet_x0_25_msmt17.pt',
109132
help='path to ReID model weights'),
110-
click.option('--classes', type=int, multiple=True,
111-
help='filter by class indices')
133+
click.option('--classes', type=str, default=None,
134+
help='filter by class indices, e.g. 0 or "0,1"')
112135
]
113136
for opt in reversed(options):
114137
func = opt(func)
@@ -123,8 +146,8 @@ def plural_model_options(func):
123146
click.option('--reid-model', type=Path, multiple=True,
124147
default=[WEIGHTS / 'osnet_x0_25_msmt17.pt'],
125148
help='one or more ReID model weights'),
126-
click.option('--classes', type=int, multiple=True,
127-
default=[0], help='filter by class indices')
149+
click.option('--classes', type=str, default=None,
150+
help='filter by class indices, e.g. 0 or "0,1"')
128151
]
129152
for opt in reversed(options):
130153
func = opt(func)
@@ -291,7 +314,7 @@ def track(ctx, detector, reid, tracker, yolo_model, reid_model, classes, **kwarg
291314
params = {**kwargs,
292315
'yolo_model': yolo_model,
293316
'reid_model': reid_model,
294-
'classes': list(classes) if classes else None,
317+
'classes': parse_classes(classes),
295318
'source': src,
296319
'benchmark': bench,
297320
'split': split}
@@ -324,7 +347,7 @@ def generate(ctx, detector, reid, yolo_model, reid_model, classes, **kwargs):
324347
params = {**kwargs,
325348
'yolo_model': list(yolo_model),
326349
'reid_model': list(reid_model),
327-
'classes': list(classes),
350+
'classes': parse_classes(classes),
328351
'source': src,
329352
'benchmark': bench,
330353
'split': split}
@@ -360,7 +383,7 @@ def eval(ctx, detector, reid, tracker, yolo_model, reid_model, classes, **kwargs
360383
params = {**kwargs,
361384
'yolo_model': list(yolo_model),
362385
'reid_model': list(reid_model),
363-
'classes': [0],
386+
'classes': parse_classes(classes),
364387
'source': src,
365388
'benchmark': bench,
366389
'split': split}
@@ -397,7 +420,7 @@ def tune(ctx, detector, reid, tracker, yolo_model, reid_model, classes, **kwargs
397420
params = {**kwargs,
398421
'yolo_model': list(yolo_model),
399422
'reid_model': list(reid_model),
400-
'classes': list(classes),
423+
'classes': parse_classes(classes),
401424
'source': src,
402425
'benchmark': bench,
403426
'split': split}

0 commit comments

Comments
 (0)