Commit cddc9ee

Author: pfeatherstone

- simplified createOptimizer()
- replaced OneCycle with simple createScheduler() - trying Yolov26

1 parent d3856ae

File tree

1 file changed

src/train_coco.py

Lines changed: 37 additions & 46 deletions
@@ -1,5 +1,6 @@
 import argparse
 from PIL import ImageFile
+import math
 import torch
 import torch.nn.functional as F
 import torch.utils.data
@@ -12,6 +13,7 @@
 from models import *
 ImageFile.LOAD_TRUNCATED_IMAGES = True
 
+
 parser = argparse.ArgumentParser()
 parser.add_argument("--nepochs", type=int, default=100, help="Number of epochs")
 parser.add_argument("--batchsize", type=int, default=32, help="Batch size")
@@ -25,6 +27,7 @@
 args = parser.parse_args()
 args.nworkers = torch.multiprocessing.cpu_count() // 2 if args.nworkers == 0 else args.nworkers
 
+
 class CocoWrapper(torch.utils.data.Dataset):
     def __init__(self, root, annFile, transforms=[]):
         super().__init__()
@@ -46,6 +49,7 @@ def __getitem__(self, index):
         target = torch.cat([boxes,classes], -1)
         return img, target
 
+
 def CocoCollator(batch):
     imgs, targets = zip(*batch)
     N = max(t.shape[0] for t in targets)
@@ -57,37 +61,29 @@ def CocoCollator(batch):
     targets = torch.stack(targets, 0)
     return imgs, targets
 
-def createOptimizer(self: torch.nn.Module, momentum=0.9, lr=0.001, decay=0.0001):
-    bn = tuple(v for k, v in nn.__dict__.items() if "Norm" in k) # normalization layers
-    g = [], [], []
-    for module_name, module in self.named_modules():
-        for param_name, param in module.named_parameters(recurse=False):
-            fullname = f"{module_name}.{param_name}" if module_name else param_name
-            if "bias" in fullname:
-                g[2].append(param) # bias (no decay)
-            elif isinstance(module, bn):
-                g[1].append(param) # weight (no decay)
-            else:
-                g[0].append(param) # weight (with decay)
-    num_non_decayed_biases = sum(p.numel() for p in g[2])
-    num_non_decayed_weights = sum(p.numel() for p in g[1])
-    num_decayed_weights = sum(p.numel() for p in g[0])
-    print(f"num non-decayed biases : {len(g[2])}, with {num_non_decayed_biases} parameters")
-    print(f"num non-decayed weights : {len(g[1])}, with {num_non_decayed_weights} parameters")
-    print(f"num decayed weights : {len(g[0])}, with {num_decayed_weights} parameters")
-    assert num_non_decayed_biases + num_non_decayed_weights + num_decayed_weights == sum(p.numel() for p in self.parameters() if p.requires_grad)
-    optimizer = torch.optim.SGD(g[2], lr=lr, momentum=momentum, nesterov=True)
-    # optimizer = torch.optim.AdamW(g[2], lr=lr, betas=(momentum, 0.999), fused=True)
-    optimizer.add_param_group({"params": g[0], "weight_decay": decay}) # add g0 with weight_decay
-    optimizer.add_param_group({"params": g[1], "weight_decay": 0.0}) # add g1 (BatchNorm2d weights)
+
+def createOptimizer(module: torch.nn.Module, momentum=0.9, lr=0.001, decay=0.01):
+    wd_params = [p for p in module.parameters() if p.dim() >= 2]
+    no_wd_params = [p for p in module.parameters() if p.dim() < 2]
+    optim_groups = [{'params': wd_params, 'weight_decay': decay},
+                    {'params': no_wd_params, 'weight_decay': 0.0}]
+    optimizer = torch.optim.AdamW(optim_groups, lr=lr, betas=(momentum, 0.99), fused=True)
     return optimizer
 
+
+def createScheduler(optimizer, total_steps, warmup_steps):
+    t0 = warmup_steps
+    t1 = total_steps
+    def fn0(t): return math.sin(t*math.pi/(2*t0))**2
+    def fn1(t): return math.cos((t-t0)*math.pi/(2*(t1-t0)))**2
+    return torch.optim.lr_scheduler.LambdaLR(optimizer, lambda t: fn0(t) if t < t0 else fn1(t))
+
+
 class LitModule(pl.LightningModule):
-    def __init__(self, net, nc, nsteps):
+    def __init__(self, net, nc):
         super().__init__()
         self.net = net
         self.nc = nc
-        self.nsteps = nsteps
 
     def training_step(self, batch, batch_idx):
         return self.step(batch, batch_idx, self.trainer.num_training_batches, is_training=True)
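
The two new helpers are small enough to exercise on their own. Below is a standalone sketch (not part of the commit) that mirrors their logic against a toy Conv+BatchNorm model: it shows which parameters land in the decayed group under the `p.dim() >= 2` rule, and that the schedule ramps up sin²-style over the warmup and decays cos²-style afterwards. `fused=True` from the diff is dropped so the sketch also runs on CPU (fused AdamW requires CUDA tensors); the helper names here are renamed to avoid confusion with the script's own.

import math
import torch

# Same grouping rule as createOptimizer() above: >=2-D tensors (conv/linear
# weights) get weight decay; 1-D tensors (biases, norm affine params) do not.
def make_optimizer(module, momentum=0.9, lr=0.001, decay=0.01):
    wd    = [p for p in module.parameters() if p.dim() >= 2]
    no_wd = [p for p in module.parameters() if p.dim() < 2]
    return torch.optim.AdamW([{'params': wd,    'weight_decay': decay},
                              {'params': no_wd, 'weight_decay': 0.0}],
                             lr=lr, betas=(momentum, 0.99))

# Same schedule as createScheduler() above: sin^2 ramp from 0 to 1 over the
# warmup steps, then cos^2 decay from 1 back to 0 over the remaining steps.
def make_scheduler(optimizer, total_steps, warmup_steps):
    t0, t1 = warmup_steps, total_steps
    def mult(t):
        if t < t0:
            return math.sin(t * math.pi / (2 * t0)) ** 2
        return math.cos((t - t0) * math.pi / (2 * (t1 - t0))) ** 2
    return torch.optim.lr_scheduler.LambdaLR(optimizer, mult)

toy = torch.nn.Sequential(torch.nn.Conv2d(3, 8, 3), torch.nn.BatchNorm2d(8))
opt = make_optimizer(toy)
# [1, 3]: only the 4-D conv weight decays; the conv bias and both
# BatchNorm parameters are 1-D and go to the no-decay group
print([len(g['params']) for g in opt.param_groups])

sched = make_scheduler(opt, total_steps=100, warmup_steps=10)
lrs = []
for _ in range(100):
    lrs.append(opt.param_groups[0]['lr'])
    opt.step()
    sched.step()
print(max(lrs))  # peaks at the base lr (0.001) at step 10, then decays towards 0

Note the two halves join continuously: at t = warmup_steps, sin²(π/2) = cos²(0) = 1, so the multiplier peaks at exactly 1 before the decay begins.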
@@ -98,26 +94,25 @@ def validation_step(self, batch, batch_idx):
     def step(self, batch, batch_idx, nbatches, is_training):
         imgs, targets = batch
         preds, losses = self.net(imgs, targets)
-        loss = 7.5 * losses['iou'] + 0.5 * losses['cls'] + 0.5 * losses['obj']
-        # loss = 7.5 * losses['iou'] + 0.5 * losses['cls'] + 1.5 * losses['dfl']
+        loss = 7.5 * losses['iou'] + 0.5 * losses['cls'] + 0.5 * losses['obj'] + (losses['dfl'] if 'dfl' in losses else 0)
 
         label = "train" if is_training else "val"
-        self.log("loss/obj/" + label, losses['obj'].item(), logger=False, prog_bar=False, on_step=True)
-        # self.log("loss/dfl/" + label, losses['dfl'].item(), logger=False, prog_bar=False, on_step=True)
-        self.log("loss/cls/" + label, losses['cls'].item(), logger=False, prog_bar=False, on_step=True)
-        self.log("loss/iou/" + label, losses['iou'].item(), logger=False, prog_bar=False, on_step=True)
-        self.log("loss/sum/" + label, loss.item(), logger=False, prog_bar=True, on_step=True, on_epoch=True)
+        self.log("loss/sum/" + label, loss.item(), logger=False, prog_bar=True, on_step=True, on_epoch=True)
+        if 'obj' in losses: self.log("loss/obj/" + label, losses['obj'].item(), logger=False, prog_bar=False, on_step=True)
+        if 'dfl' in losses: self.log("loss/dfl/" + label, losses['dfl'].item(), logger=False, prog_bar=False, on_step=True)
+        if 'cls' in losses: self.log("loss/cls/" + label, losses['cls'].item(), logger=False, prog_bar=False, on_step=True)
+        if 'iou' in losses: self.log("loss/iou/" + label, losses['iou'].item(), logger=False, prog_bar=False, on_step=True)
 
         if self.trainer.is_global_zero:
             summary = self.logger.experiment
             epoch = self.current_epoch
             totalBatch = (epoch + batch_idx / nbatches) * 1000
 
-            summary.add_scalars("loss/obj", {label: losses['obj'].item()}, totalBatch)
-            # summary.add_scalars("loss/dfl", {label: losses['dfl'].item()}, totalBatch)
-            summary.add_scalars("loss/cls", {label: losses['cls'].item()}, totalBatch)
-            summary.add_scalars("loss/iou", {label: losses['iou'].item()}, totalBatch)
-            summary.add_scalars("loss/sum", {label: loss.item()}, totalBatch)
+            summary.add_scalars("loss/sum", {label: loss.item()}, totalBatch)
+            if 'obj' in losses: summary.add_scalars("loss/obj", {label: losses['obj'].item()}, totalBatch)
+            if 'dfl' in losses: summary.add_scalars("loss/dfl", {label: losses['dfl'].item()}, totalBatch)
+            if 'cls' in losses: summary.add_scalars("loss/cls", {label: losses['cls'].item()}, totalBatch)
+            if 'iou' in losses: summary.add_scalars("loss/iou", {label: losses['iou'].item()}, totalBatch)
 
         if batch_idx % 50 == 0:
             with torch.no_grad():
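
Every log call is now guarded with `if key in losses`, presumably because the Yolov26 head returns a different set of loss terms (e.g. `dfl`) than the Yolov3 head's `obj`. The four guarded `self.log` lines could equally be folded into a loop; a sketch of that equivalent form, as it would appear inside step():

# Equivalent compact form of the guarded logging above (a sketch,
# not what the commit does): log whichever terms the model returned.
for name in ('obj', 'dfl', 'cls', 'iou'):
    if name in losses:
        self.log(f"loss/{name}/{label}", losses[name].item(),
                 logger=False, prog_bar=False, on_step=True)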
@@ -133,11 +128,9 @@ def step(self, batch, batch_idx, nbatches, is_training):
         return loss
 
     def configure_optimizers(self):
-        optimizer = createOptimizer(self, lr=args.lr)
-        scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer,
-                                                        max_lr=[g['lr'] for g in optimizer.param_groups],
-                                                        total_steps=self.nsteps,
-                                                        pct_start=args.nwarmup/self.nsteps)
+        total_steps = self.trainer.estimated_stepping_batches
+        optimizer = createOptimizer(self, lr=args.lr)
+        scheduler = createScheduler(optimizer, total_steps, args.nwarmup)
         return {'optimizer': optimizer, 'lr_scheduler': {'scheduler': scheduler, 'interval': "step", "frequency": 1}}
 
 torch.set_float32_matmul_precision('medium')
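
`self.trainer.estimated_stepping_batches` replaces the hand-computed `nsteps = len(trainLoader) * args.nepochs`, with the advantage that Lightning also accounts for gradient accumulation and multi-device setups. Since the scheduler is registered with `interval: "step"` and `frequency: 1`, Lightning advances it once per optimizer step. A plain-PyTorch sketch of the loop this dict asks Lightning to run (reusing names from the script, with `net` as the raw model and device handling omitted):

# Sketch only: what interval="step", frequency=1 amounts to.
total_steps = len(trainLoader) * args.nepochs  # roughly what
# estimated_stepping_batches returns (it also folds in accumulation/devices)
optimizer = createOptimizer(net, lr=args.lr)
scheduler = createScheduler(optimizer, total_steps, args.nwarmup)
for epoch in range(args.nepochs):
    for imgs, targets in trainLoader:
        preds, losses = net(imgs, targets)
        loss = 7.5 * losses['iou'] + 0.5 * losses['cls'] + 0.5 * losses['obj'] \
               + (losses['dfl'] if 'dfl' in losses else 0)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        scheduler.step()  # advance the LR schedule every batch, not every epoch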
@@ -156,11 +149,9 @@ def configure_optimizers(self):
 nclasses = len(valset.names)
 trainLoader = torch.utils.data.DataLoader(trainset, batch_size=args.batchsize, shuffle=True, collate_fn=CocoCollator, num_workers=args.nworkers)
 valLoader = torch.utils.data.DataLoader(valset, batch_size=args.batchsize, collate_fn=CocoCollator, num_workers=args.nworkers)
-nsteps = len(trainLoader) * args.nepochs
 
-net = Yolov3(nclasses, spp=True)
-init_batchnorms(net)
-net = LitModule(net, nclasses, nsteps)
+net = Yolov26('n', nclasses)
+net = LitModule(net, nclasses)
 
 trainer = pl.Trainer(max_epochs=args.nepochs,
                      accelerator='gpu',
