-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathutils.py
More file actions
86 lines (67 loc) · 2.46 KB
/
utils.py
File metadata and controls
86 lines (67 loc) · 2.46 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
"""
Utility functions for 5D YOLOv8 + GPS
"""
import os
import random
import torch
import numpy as np
import torch.optim as optim
from pathlib import Path
from typing import Dict
import config as cfg
from models import YOLO5D
def set_seeds(seed=None):
    """Seed every RNG in play (random, NumPy, torch, CUDA) for reproducibility.

    Args:
        seed: Seed value to use; when None, falls back to ``cfg.SEED``.
    """
    if seed is None:
        seed = cfg.SEED
    # NOTE(review): setting PYTHONHASHSEED after interpreter start does not
    # change hash randomization for this process; kept for child processes.
    os.environ['PYTHONHASHSEED'] = str(seed)
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
def create_optimizer(model: YOLO5D) -> optim.Optimizer:
    """Build an Adam optimizer with two learning-rate groups.

    Parameters whose names contain 'adapt', 't_fuse' or 'gps_head' (the newly
    added modules) train at ``cfg.LR_NEW``; all remaining (backbone) parameters
    train at ``cfg.LR_BACKBONE``.

    Args:
        model: The YOLO5D model whose parameters are being optimized.

    Returns:
        A configured ``torch.optim.Adam`` instance.
    """
    new_module_keys = ('adapt', 't_fuse', 'gps_head')
    grouped = {'backbone': [], 'new': []}
    for param_name, param in model.named_parameters():
        bucket = 'new' if any(key in param_name for key in new_module_keys) else 'backbone'
        grouped[bucket].append(param)
    return optim.Adam([
        {"params": grouped['backbone'], "lr": cfg.LR_BACKBONE},
        {"params": grouped['new'], "lr": cfg.LR_NEW},
    ])
def save_checkpoint(model, epoch, optimizer, loss, is_best=False, checkpoint_dir="ckpts"):
    """Write model/optimizer state plus bookkeeping info to a checkpoint file.

    Args:
        model: Model whose ``state_dict`` is persisted.
        epoch: Current epoch number (embedded in the filename).
        optimizer: Optimizer whose ``state_dict`` is persisted alongside.
        loss: Loss value recorded in the checkpoint.
        is_best: When True, save under the fixed name ``yolo5d_best.pt``
            instead of the per-epoch name.
        checkpoint_dir: Directory (str or Path) for checkpoints; created
            (including parents) if it does not exist.

    Returns:
        Path of the checkpoint file that was written.
    """
    target_dir = Path(checkpoint_dir) if isinstance(checkpoint_dir, str) else checkpoint_dir
    target_dir.mkdir(exist_ok=True, parents=True)
    # Best checkpoints overwrite a single well-known file; regular ones are
    # kept per epoch with zero-padded numbering.
    filename = "yolo5d_best.pt" if is_best else f"yolo5d_epoch{epoch:03d}.pt"
    destination = target_dir / filename
    torch.save(
        {
            "epoch": epoch,
            "model_state_dict": model.state_dict(),
            "optimizer_state_dict": optimizer.state_dict(),
            "loss": loss,
        },
        destination,
    )
    return destination
def load_checkpoint(model, checkpoint_path, optimizer=None):
    """Restore model (and optionally optimizer) state from a checkpoint.

    Args:
        model: Model to receive the saved ``model_state_dict``.
        checkpoint_path: Path to a checkpoint written by ``save_checkpoint``.
        optimizer: If provided, its state is restored too — but only when the
            checkpoint actually carries an ``optimizer_state_dict`` entry.

    Returns:
        The full checkpoint dict (epoch, state dicts, loss).
    """
    # Map tensors onto the configured device so CPU/GPU checkpoints interoperate.
    state = torch.load(checkpoint_path, map_location=cfg.DEVICE)
    model.load_state_dict(state["model_state_dict"])
    if optimizer is not None:
        if "optimizer_state_dict" in state:
            optimizer.load_state_dict(state["optimizer_state_dict"])
    return state