Skip to content

Commit 73fa024

Browse files
authored
Merge pull request #153 from wenh06/master
add compatibility for eval_model for multi-GPU, following the advice of @xungeer29
2 parents 25dd861 + 9c8bcb6 commit 73fa024

File tree

1 file changed

+4
-1
lines changed

1 file changed

+4
-1
lines changed

train.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -417,7 +417,10 @@ def burnin_schedule(i):
417417
else:
418418
eval_model = Yolov4(cfg.pretrained, n_classes=cfg.classes, inference=True)
419419
# eval_model = Yolov4(yolov4conv137weight=None, n_classes=config.classes, inference=True)
420-
eval_model.load_state_dict(model.state_dict())
420+
if torch.cuda.device_count() > 1:
421+
eval_model.load_state_dict(model.module.state_dict())
422+
else:
423+
eval_model.load_state_dict(model.state_dict())
421424
eval_model.to(device)
422425
evaluator = evaluate(eval_model, val_loader, config, device)
423426
del eval_model

0 commit comments

Comments
 (0)