Skip to content

Commit 8db6a15

Browse files
Set imgsz for evaluation in cli.py
1 parent 96c6d61 commit 8db6a15

File tree

2 files changed

+30
-117
lines changed

2 files changed

+30
-117
lines changed

boxmot/engine/cli.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -431,7 +431,8 @@ def eval(ctx, detector, reid, tracker, yolo_model, reid_model, classes, **kwargs
431431
'classes': parse_classes(classes),
432432
'source': src,
433433
'benchmark': bench,
434-
'split': split}
434+
'split': split,
435+
'imgsz': [1088, 1920]}
435436
args = SimpleNamespace(**params)
436437
from boxmot.engine.evaluator import main as run_eval
437438
run_eval(args)

boxmot/engine/evaluator.py

Lines changed: 28 additions & 116 deletions
Original file line numberDiff line numberDiff line change
@@ -38,11 +38,11 @@
3838
from ultralytics import YOLO
3939

4040
from boxmot.detectors import (
41-
default_imgsz,
4241
get_yolo_inferer,
4342
is_rtdetr_model,
4443
is_ultralytics_model,
4544
is_yolox_model,
45+
default_imgsz
4646
)
4747
from boxmot.utils.mot_utils import convert_to_mot_format, write_mot_results
4848
from boxmot.reid.core.auto_backend import ReidAutoBackend
@@ -116,114 +116,6 @@ def eval_init(args,
116116
args.project.mkdir(parents=True, exist_ok=True)
117117

118118

119-
def generate_dets_embs(args: argparse.Namespace, y: Path, source: Path) -> None:
120-
"""
121-
Generates detections and embeddings for the specified
122-
arguments, YOLO model and source.
123-
124-
Args:
125-
args (Namespace): Parsed command line arguments.
126-
y (Path): Path to the YOLO model file.
127-
source (Path): Path to the source directory.
128-
"""
129-
WEIGHTS.mkdir(parents=True, exist_ok=True)
130-
131-
args.imgsz = [1088, 1920]
132-
133-
seq_name = source.parent.name if source.name == "img1" else source.name
134-
135-
yolo = YOLO(
136-
y if is_ultralytics_model(y)
137-
else 'yolov8n.pt',
138-
)
139-
140-
results = yolo(
141-
source=source,
142-
conf=args.conf,
143-
iou=args.iou,
144-
agnostic_nms=args.agnostic_nms,
145-
stream=True,
146-
device=args.device,
147-
verbose=False,
148-
exist_ok=args.exist_ok,
149-
project=args.project,
150-
name=args.name,
151-
classes=args.classes,
152-
imgsz=args.imgsz,
153-
vid_stride=args.vid_stride,
154-
)
155-
156-
if not is_ultralytics_model(y):
157-
m = get_yolo_inferer(y)
158-
yolo_model = m(model=y, device=yolo.predictor.device,
159-
args=yolo.predictor.args)
160-
yolo.predictor.model = yolo_model
161-
162-
# If current model is YOLOX or RTDetr, change the preprocess and postprocess
163-
if is_yolox_model(y) or is_rtdetr_model(y):
164-
# add callback to save image paths for further processing
165-
yolo.add_callback("on_predict_batch_start",
166-
lambda p: yolo_model.update_im_paths(p))
167-
yolo.predictor.preprocess = (
168-
lambda im: yolo_model.preprocess(im=im))
169-
yolo.predictor.postprocess = (
170-
lambda preds, im, im0s:
171-
yolo_model.postprocess(preds=preds, im=im, im0s=im0s))
172-
173-
reids = []
174-
for r in args.reid_model:
175-
reid_model = ReidAutoBackend(weights=r,
176-
device=yolo.predictor.device,
177-
half=args.half).model
178-
reids.append(reid_model)
179-
embs_path = args.project / 'dets_n_embs' / y.stem / 'embs' / r.stem / (seq_name + '.txt')
180-
embs_path.parent.mkdir(parents=True, exist_ok=True)
181-
embs_path.touch(exist_ok=True)
182-
183-
if os.path.getsize(embs_path) > 0:
184-
open(embs_path, 'w').close()
185-
186-
yolo.predictor.custom_args = args
187-
188-
dets_path = args.project / 'dets_n_embs' / y.stem / 'dets' / (seq_name + '.txt')
189-
dets_path.parent.mkdir(parents=True, exist_ok=True)
190-
dets_path.touch(exist_ok=True)
191-
192-
if os.path.getsize(dets_path) > 0:
193-
open(dets_path, 'w').close()
194-
195-
with open(str(dets_path), 'ab+') as f:
196-
np.savetxt(f, [], fmt='%f', header=str(source))
197-
198-
for frame_idx, r in enumerate(tqdm(results, desc="Frames")):
199-
nr_dets = len(r.boxes)
200-
frame_idx = torch.full((1, 1), frame_idx + 1).repeat(nr_dets, 1)
201-
img = r.orig_img
202-
203-
dets = np.concatenate(
204-
[
205-
frame_idx,
206-
r.boxes.xyxy.to('cpu'),
207-
r.boxes.conf.unsqueeze(1).to('cpu'),
208-
r.boxes.cls.unsqueeze(1).to('cpu'),
209-
], axis=1
210-
)
211-
212-
# Keep boxes even if they extend outside the image; only drop invalid geometry
213-
boxes = r.boxes.xyxy.to('cpu').numpy()
214-
positive = (boxes[:, 2] > boxes[:, 0]) & (boxes[:, 3] > boxes[:, 1])
215-
dets = dets[positive]
216-
217-
with open(str(dets_path), 'ab+') as f:
218-
np.savetxt(f, dets, fmt='%f')
219-
220-
for reid, reid_model_name in zip(reids, args.reid_model):
221-
embs = reid.get_features(dets[:, 1:5], img)
222-
embs_path = args.project / "dets_n_embs" / y.stem / 'embs' / reid_model_name.stem / (seq_name + '.txt')
223-
with open(str(embs_path), 'ab+') as f:
224-
np.savetxt(f, embs, fmt='%f')
225-
226-
227119
def parse_mot_results(results: str) -> dict:
228120
"""
229121
Extracts COMBINED HOTA, MOTA, IDF1, AssA, AssRe, IDSW, and IDs from MOTChallenge evaluation output.
@@ -398,8 +290,8 @@ def generate_dets_embs_batched(args: argparse.Namespace, y: Path, source_root: P
398290
auto_batch = bool(getattr(args, "auto_batch", True))
399291
resume = bool(getattr(args, "resume", True))
400292

401-
#if args.imgsz is None:
402-
args.imgsz = [1088, 1920]
293+
if args.imgsz is None:
294+
args.imgsz = default_imgsz()
403295

404296
yolo = YOLO(y if is_ultralytics_model(y) else 'yolov8n.pt')
405297

@@ -623,6 +515,22 @@ def setup_custom_model(predictor):
623515
confs = r.boxes.conf
624516
clss = r.boxes.cls
625517

518+
h, w = img.shape[:2]
519+
520+
x1, y1, x2, y2 = boxes.unbind(1)
521+
w_t = boxes.new_tensor(float(w))
522+
h_t = boxes.new_tensor(float(h))
523+
zero = torch.zeros_like(x1)
524+
525+
boxes_filter = (
526+
(torch.maximum(zero, x1) < torch.minimum(x2, w_t)) &
527+
(torch.maximum(zero, y1) < torch.minimum(y2, h_t))
528+
)
529+
530+
boxes = boxes[boxes_filter]
531+
confs = confs[boxes_filter]
532+
clss = clss[boxes_filter]
533+
626534
nr = int(boxes.shape[0])
627535
if nr == 0:
628536
pbar.update(1)
@@ -642,14 +550,15 @@ def setup_custom_model(predictor):
642550

643551
widths = boxes[:, 2] - boxes[:, 0]
644552
heights = boxes[:, 3] - boxes[:, 1]
645-
min_wh = (widths >= 10.0) & (heights >= 10.0)
646-
if not bool(min_wh.any()):
553+
areas = widths * heights
554+
valid_area = areas >= 20.0
555+
if not bool(valid_area.any()):
647556
pbar.update(1)
648557
continue
649558

650-
boxes = boxes[min_wh]
651-
confs = confs[min_wh]
652-
clss = clss[min_wh]
559+
boxes = boxes[valid_area]
560+
confs = confs[valid_area]
561+
clss = clss[valid_area]
653562

654563
frame_col = torch.full(
655564
(boxes.shape[0], 1),
@@ -663,6 +572,8 @@ def setup_custom_model(predictor):
663572
)
664573

665574
dets_np = dets_t.detach().float().cpu().numpy()
575+
dets_np[:, 1:5] = np.rint(dets_np[:, 1:5]) # round xyxy only
576+
666577
np.savetxt(det_fhs[seq_name], dets_np, fmt="%f")
667578

668579
det_boxes_np = dets_np[:, 1:5]
@@ -1132,6 +1043,7 @@ def main(args):
11321043
LOGGER.opt(colors=True).info(f"<bold>ReID:</bold> <cyan>{args.reid_model[0]}</cyan>")
11331044
LOGGER.opt(colors=True).info(f"<bold>Tracker:</bold> <cyan>{args.tracking_method}</cyan>")
11341045
LOGGER.opt(colors=True).info(f"<bold>Benchmark:</bold> <cyan>{args.source}</cyan>")
1046+
LOGGER.opt(colors=True).info(f"<bold>Image size:</bold> <cyan>{getattr(args, 'imgsz', None)}</cyan>")
11351047
LOGGER.opt(colors=True).info("<blue>" + "="*60 + "</blue>")
11361048

11371049
# Step 1: Download TrackEval

0 commit comments

Comments (0)