11import argparse
22from PIL import ImageFile
3+ import math
34import torch
45import torch .nn .functional as F
56import torch .utils .data
1213from models import *
1314ImageFile .LOAD_TRUNCATED_IMAGES = True
1415
16+
1517parser = argparse .ArgumentParser ()
1618parser .add_argument ("--nepochs" , type = int , default = 100 , help = "Number of epochs" )
1719parser .add_argument ("--batchsize" , type = int , default = 32 , help = "Batch size" )
2527args = parser .parse_args ()
2628args .nworkers = torch .multiprocessing .cpu_count () // 2 if args .nworkers == 0 else args .nworkers
2729
30+
2831class CocoWrapper (torch .utils .data .Dataset ):
2932 def __init__ (self , root , annFile , transforms = []):
3033 super ().__init__ ()
@@ -46,6 +49,7 @@ def __getitem__(self, index):
4649 target = torch .cat ([boxes ,classes ], - 1 )
4750 return img , target
4851
52+
4953def CocoCollator (batch ):
5054 imgs , targets = zip (* batch )
5155 N = max (t .shape [0 ] for t in targets )
@@ -57,37 +61,29 @@ def CocoCollator(batch):
5761 targets = torch .stack (targets , 0 )
5862 return imgs , targets
5963
60- def createOptimizer (self : torch .nn .Module , momentum = 0.9 , lr = 0.001 , decay = 0.0001 ):
61- bn = tuple (v for k , v in nn .__dict__ .items () if "Norm" in k ) # normalization layers
62- g = [], [], []
63- for module_name , module in self .named_modules ():
64- for param_name , param in module .named_parameters (recurse = False ):
65- fullname = f"{ module_name } .{ param_name } " if module_name else param_name
66- if "bias" in fullname :
67- g [2 ].append (param ) # bias (no decay)
68- elif isinstance (module , bn ):
69- g [1 ].append (param ) # weight (no decay)
70- else :
71- g [0 ].append (param ) # weight (with decay)
72- num_non_decayed_biases = sum (p .numel () for p in g [2 ])
73- num_non_decayed_weights = sum (p .numel () for p in g [1 ])
74- num_decayed_weights = sum (p .numel () for p in g [0 ])
75- print (f"num non-decayed biases : { len (g [2 ])} , with { num_non_decayed_biases } parameters" )
76- print (f"num non-decayed weights : { len (g [1 ])} , with { num_non_decayed_weights } parameters" )
77- print (f"num decayed weights : { len (g [0 ])} , with { num_decayed_weights } parameters" )
78- assert num_non_decayed_biases + num_non_decayed_weights + num_decayed_weights == sum (p .numel () for p in self .parameters () if p .requires_grad )
79- optimizer = torch .optim .SGD (g [2 ], lr = lr , momentum = momentum , nesterov = True )
80- # optimizer = torch.optim.AdamW(g[2], lr=lr, betas=(momentum, 0.999), fused=True)
81- optimizer .add_param_group ({"params" : g [0 ], "weight_decay" : decay }) # add g0 with weight_decay
82- optimizer .add_param_group ({"params" : g [1 ], "weight_decay" : 0.0 }) # add g1 (BatchNorm2d weights)
64+
def createOptimizer(module: torch.nn.Module, momentum=0.9, lr=0.001, decay=0.01):
    """Build an AdamW optimizer for *module* with selective weight decay.

    Weight decay is applied only to parameters with >= 2 dimensions
    (matrices / conv kernels); 1-D parameters (biases, norm-layer
    scales/shifts) are excluded — the standard decay-grouping heuristic.

    Args:
        module: the model whose parameters are optimized.
        momentum: used as AdamW's beta1.
        lr: learning rate.
        decay: weight-decay coefficient for the decayed group.

    Returns:
        torch.optim.AdamW with two parameter groups (decayed / non-decayed).
    """
    params = list(module.parameters())
    wd_params = [p for p in params if p.dim() >= 2]
    no_wd_params = [p for p in params if p.dim() < 2]
    optim_groups = [
        {'params': wd_params, 'weight_decay': decay},
        {'params': no_wd_params, 'weight_decay': 0.0},
    ]
    # Bug fix: fused=True was hard-coded, but the fused AdamW kernel requires
    # CUDA parameters on most torch builds and raises a RuntimeError on CPU.
    # Enable it only when every parameter actually lives on a CUDA device.
    use_fused = len(params) > 0 and all(p.is_cuda for p in params)
    return torch.optim.AdamW(optim_groups, lr=lr, betas=(momentum, 0.99), fused=use_fused)
8472
73+
def createScheduler(optimizer, total_steps, warmup_steps):
    """Per-step LR schedule: sin^2 ramp over the warmup, cos^2 anneal after.

    The multiplier rises smoothly from 0 to 1 during the first
    ``warmup_steps`` steps, then decays smoothly back to 0 by
    ``total_steps``.  Applied multiplicatively to each group's base lr.
    """
    def lr_lambda(step):
        if step < warmup_steps:
            # sin^2 goes 0 -> 1 as step goes 0 -> warmup_steps
            return math.sin(step * math.pi / (2 * warmup_steps)) ** 2
        # cos^2 goes 1 -> 0 as step goes warmup_steps -> total_steps
        progress = (step - warmup_steps) * math.pi / (2 * (total_steps - warmup_steps))
        return math.cos(progress) ** 2
    return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)
80+
81+
8582class LitModule (pl .LightningModule ):
86- def __init__ (self , net , nc , nsteps ):
83+ def __init__ (self , net , nc ):
8784 super ().__init__ ()
8885 self .net = net
8986 self .nc = nc
90- self .nsteps = nsteps
9187
9288 def training_step (self , batch , batch_idx ):
9389 return self .step (batch , batch_idx , self .trainer .num_training_batches , is_training = True )
@@ -98,26 +94,25 @@ def validation_step(self, batch, batch_idx):
9894 def step (self , batch , batch_idx , nbatches , is_training ):
9995 imgs , targets = batch
10096 preds , losses = self .net (imgs , targets )
101- loss = 7.5 * losses ['iou' ] + 0.5 * losses ['cls' ] + 0.5 * losses ['obj' ]
102- # loss = 7.5 * losses['iou'] + 0.5 * losses['cls'] + 1.5 * losses['dfl']
97+ loss = 7.5 * losses ['iou' ] + 0.5 * losses ['cls' ] + 0.5 * losses ['obj' ] + (losses ['dfl' ] if 'dfl' in losses else 0 )
10398
10499 label = "train" if is_training else "val"
105- self .log ("loss/obj /" + label , losses [ 'obj' ] .item (), logger = False , prog_bar = False , on_step = True )
106- # self.log("loss/dfl /" + label, losses['dfl '].item(), logger=False, prog_bar=False, on_step=True)
107- self .log ("loss/cls /" + label , losses ['cls ' ].item (), logger = False , prog_bar = False , on_step = True )
108- self .log ("loss/iou /" + label , losses ['iou ' ].item (), logger = False , prog_bar = False , on_step = True )
109- self .log ("loss/sum /" + label , loss .item (), logger = False , prog_bar = True , on_step = True , on_epoch = True )
100+ self .log ("loss/sum /" + label , loss .item (), logger = False , prog_bar = True , on_step = True , on_epoch = True )
101+ if 'obj' in losses : self .log ("loss/obj /" + label , losses ['obj ' ].item (), logger = False , prog_bar = False , on_step = True )
102+ if 'dfl' in losses : self .log ("loss/dfl /" + label , losses ['dfl ' ].item (), logger = False , prog_bar = False , on_step = True )
103+ if 'cls' in losses : self .log ("loss/cls /" + label , losses ['cls ' ].item (), logger = False , prog_bar = False , on_step = True )
104+ if 'iou' in losses : self .log ("loss/iou /" + label , losses [ 'iou' ] .item (), logger = False , prog_bar = False , on_step = True )
110105
111106 if self .trainer .is_global_zero :
112107 summary = self .logger .experiment
113108 epoch = self .current_epoch
114109 totalBatch = (epoch + batch_idx / nbatches ) * 1000
115110
116- summary .add_scalars ("loss/obj " , {label : losses [ 'obj' ] .item ()}, totalBatch )
117- # summary.add_scalars("loss/dfl ", {label: losses['dfl '].item()}, totalBatch)
118- summary .add_scalars ("loss/cls " , {label : losses ['cls ' ].item ()}, totalBatch )
119- summary .add_scalars ("loss/iou " , {label : losses ['iou ' ].item ()}, totalBatch )
120- summary .add_scalars ("loss/sum " , {label : loss .item ()}, totalBatch )
111+ summary .add_scalars ("loss/sum " , {label : loss .item ()}, totalBatch )
112+ if 'obj' in losses : summary .add_scalars ("loss/obj " , {label : losses ['obj ' ].item ()}, totalBatch )
113+ if 'dfl' in losses : summary .add_scalars ("loss/dfl " , {label : losses ['dfl ' ].item ()}, totalBatch )
114+ if 'cls' in losses : summary .add_scalars ("loss/cls " , {label : losses ['cls ' ].item ()}, totalBatch )
115+ if 'iou' in losses : summary .add_scalars ("loss/iou " , {label : losses [ 'iou' ] .item ()}, totalBatch )
121116
122117 if batch_idx % 50 == 0 :
123118 with torch .no_grad ():
@@ -133,11 +128,9 @@ def step(self, batch, batch_idx, nbatches, is_training):
133128 return loss
134129
135130 def configure_optimizers (self ):
136- optimizer = createOptimizer (self , lr = args .lr )
137- scheduler = torch .optim .lr_scheduler .OneCycleLR (optimizer ,
138- max_lr = [g ['lr' ] for g in optimizer .param_groups ],
139- total_steps = self .nsteps ,
140- pct_start = args .nwarmup / self .nsteps )
131+ total_steps = self .trainer .estimated_stepping_batches
132+ optimizer = createOptimizer (self , lr = args .lr )
133+ scheduler = createScheduler (optimizer , total_steps , args .nwarmup )
141134 return {'optimizer' : optimizer , 'lr_scheduler' : {'scheduler' : scheduler , 'interval' : "step" , "frequency" : 1 }}
142135
143136torch .set_float32_matmul_precision ('medium' )
@@ -156,11 +149,9 @@ def configure_optimizers(self):
156149nclasses = len (valset .names )
157150trainLoader = torch .utils .data .DataLoader (trainset , batch_size = args .batchsize , shuffle = True , collate_fn = CocoCollator , num_workers = args .nworkers )
158151valLoader = torch .utils .data .DataLoader (valset , batch_size = args .batchsize , collate_fn = CocoCollator , num_workers = args .nworkers )
159- nsteps = len (trainLoader ) * args .nepochs
160152
161- net = Yolov3 (nclasses , spp = True )
162- init_batchnorms (net )
163- net = LitModule (net , nclasses , nsteps )
153+ net = Yolov26 ('n' , nclasses )
154+ net = LitModule (net , nclasses )
164155
165156trainer = pl .Trainer (max_epochs = args .nepochs ,
166157 accelerator = 'gpu' ,
0 commit comments