1818import tqdm
1919
2020import pytorch_lightning as pl
21- from pytorch_lightning .callbacks import Callback , ModelCheckpoint
21+ from pytorch_lightning .callbacks import ModelCheckpoint
2222
23- # from pytorch_lightning.loggers import Logger,
24- from pytorch_lightning .utilities import rank_zero_only
2523import torch .optim as optim
2624import torch
2725import torch .nn as nn
3230import trackertraincode .train as train
3331import trackertraincode .pipelines
3432
35- from trackertraincode .neuralnets .io import complement_lightning_checkpoint
36- from scripts .export_model import convert_posemodel_onnx
3733from trackertraincode .datasets .batch import Batch
3834from trackertraincode .pipelines import Tag
3935
@@ -161,11 +157,6 @@ def create_optimizer(net, args: MyArgs):
161157 return optimizer , scheduler
162158
163159
class SaveBestSpec(NamedTuple):
    """Specification of how to combine validation metrics into a 'best model' score.

    Used by setup_losses(), which instantiates it as
    SaveBestSpec([1.0, 1.0, 1.0], ["rot", "xy", "sz"]).
    """

    # Relative weight for each metric when combining into a single score.
    weights: List[float]
    # Metric names, matched by position to `weights` (e.g. "rot", "xy", "sz").
    names: List[str]
169160def setup_losses (args : MyArgs , net ):
170161 C = train .Criterion
171162 cregularize = [
@@ -259,9 +250,7 @@ def wrapped(step):
259250 ),
260251 }
261252
262- savebest = SaveBestSpec ([1.0 , 1.0 , 1.0 ], ["rot" , "xy" , "sz" ])
263-
264- return train_criterions , test_criterions , savebest
253+ return train_criterions , test_criterions
265254
266255
267256def create_net (args : MyArgs ):
@@ -281,7 +270,7 @@ def __init__(self, args: MyArgs):
281270 super ().__init__ ()
282271 self ._args = args
283272 self ._model = create_net (args )
284- train_criterions , test_criterions , savebest = setup_losses (args , self ._model )
273+ train_criterions , test_criterions = setup_losses (args , self ._model )
285274 self ._train_criterions = train_criterions
286275 self ._test_criterions = test_criterions
287276
@@ -315,120 +304,6 @@ def model(self):
315304 return self ._model
316305
317306
class SwaCallback(Callback):
    """Maintains a stochastic-weight-averaging (SWA) copy of the trained model.

    A CPU-resident :class:`optim.swa_utils.AveragedModel` is created when
    training starts. After ``start_epoch``, the running average is updated at
    the end of every training epoch, and the averaged weights are saved to
    ``swa.ckpt`` under the trainer's root directory when training ends.
    """

    def __init__(self, start_epoch):
        super().__init__()
        # Created lazily in on_train_start; None until then.
        self._swa_model: optim.swa_utils.AveragedModel | None = None
        self._start_epoch = start_epoch

    @property
    def swa_model(self):
        # NOTE(review): only valid after on_train_start has run; dereferences
        # the wrapped module directly.
        return self._swa_model.module

    def on_train_start(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None:
        assert isinstance(pl_module, LitModel)
        # use_buffers=True averages buffers (e.g. batch-norm running stats) too.
        self._swa_model = optim.swa_utils.AveragedModel(pl_module.model, device="cpu", use_buffers=True)

    def on_train_epoch_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None:
        assert isinstance(pl_module, LitModel)
        # Only start averaging once the configured warm-up epoch has passed.
        if trainer.current_epoch > self._start_epoch:
            self._swa_model.update_parameters(pl_module.model)

    def on_train_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None:
        assert self._swa_model is not None
        # Fix: was f"swa.ckpt" — an f-string with no placeholders (ruff F541).
        swa_filename = join(trainer.default_root_dir, "swa.ckpt")
        models.save_model(self._swa_model.module, swa_filename)
341-
342-
class MetricsGraphing(Callback):
    """Collects per-batch train losses and per-epoch validation metrics and
    renders them into ``train.pdf`` via :class:`train.TrainHistoryPlotter`."""

    def __init__(self):
        super().__init__()
        # Plotter is created once training actually starts.
        self._visu: train.TrainHistoryPlotter | None = None
        # name -> list of metric tensors accumulated over a validation run.
        self._metrics_accumulator = defaultdict(list)

    def on_train_start(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None:
        assert self._visu is None
        plot_path = join(trainer.default_root_dir, "train.pdf")
        self._visu = train.TrainHistoryPlotter(save_filename=plot_path)

    def on_train_batch_end(
        self, trainer: pl.Trainer, pl_module: pl.LightningModule, outputs: Any, batch: Any, batch_idx: int
    ):
        epoch = trainer.current_epoch
        per_task_losses: dict[str, torch.Tensor] = outputs["mt_losses"]
        for name, value in per_task_losses.items():
            self._visu.add_train_point(epoch, batch_idx, name, value.numpy())
        total_loss = outputs["loss"].detach().cpu().numpy()
        self._visu.add_train_point(epoch, batch_idx, "loss", total_loss)

    def on_train_epoch_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None:
        configs = trainer.lr_scheduler_configs
        if configs:
            # Pick the first scheduler (there should only be one) and record
            # the learning rate of its first parameter group.
            first_scheduler = next(iter(configs)).scheduler
            current_lr = first_scheduler.get_last_lr()[0]
            self._visu.add_test_point(trainer.current_epoch, "lr", current_lr)

        self._visu.summarize_train_values()

    def on_validation_start(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None:
        # Start each validation run with a fresh accumulator.
        self._metrics_accumulator = defaultdict(list)

    def on_validation_batch_end(
        self,
        trainer: pl.Trainer,
        pl_module: pl.LightningModule,
        outputs: list[train.LossVal],
        batch: Any,
        batch_idx: int,
        dataloader_idx: int = 0,
    ) -> None:
        for loss_value in outputs:
            self._metrics_accumulator[loss_value.name].append(loss_value.val)

    def on_validation_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None:
        if self._visu is None:
            # Validation sanity-check may run before on_train_start.
            return
        epoch = trainer.current_epoch - 1
        for metric_name, chunks in self._metrics_accumulator.items():
            averaged = torch.cat(chunks).mean().cpu().numpy()
            self._visu.add_test_point(epoch, metric_name, averaged)
        if trainer.current_epoch > 0:
            self._visu.update_graph()
393-
394-
class SimpleProgressBar(Callback):
    """Shows two tqdm bars: overall progress in epochs and per-epoch progress
    measured in samples."""

    def __init__(self, batchsize: int):
        super().__init__()
        self._batchsize = batchsize
        # Both bars exist only between on_train_start and on_train_end.
        self._bar: tqdm.tqdm | None = None
        self._epoch_bar: tqdm.tqdm | None = None

    def on_train_start(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None:
        self._bar = tqdm.tqdm(total=trainer.max_epochs, desc='Training', position=0)
        samples_per_epoch = trainer.num_training_batches * self._batchsize
        self._epoch_bar = tqdm.tqdm(total=samples_per_epoch, desc="Epoch", position=1)

    def on_train_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None:
        self._bar.close()
        self._epoch_bar.close()
        self._bar = None
        self._epoch_bar = None

    def on_train_epoch_start(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None:
        # Rewind the per-epoch bar while keeping its total.
        self._epoch_bar.reset(self._epoch_bar.total)

    def on_train_epoch_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None:
        self._bar.update(1)

    def on_train_batch_end(
        self,
        trainer: pl.Trainer,
        pl_module: pl.LightningModule,
        outputs: Mapping[str, Any],
        batch: list[Batch] | Batch,
        batch_idx: int,
    ) -> None:
        # A batch may be a list of sub-batches (one per dataset); count samples.
        if isinstance(batch, list):
            processed = sum(item.meta.batchsize for item in batch)
        else:
            processed = batch.meta.batchsize
        self._epoch_bar.update(processed)
430-
431-
432307def main ():
433308 np .seterr (all = "raise" )
434309 cv2 .setNumThreads (1 )
@@ -499,13 +374,13 @@ def main():
499374 save_weights_only = False ,
500375 )
501376
502- progress_cb = SimpleProgressBar (args .batchsize )
377+ progress_cb = train . SimpleProgressBar (args .batchsize )
503378
504- callbacks = [MetricsGraphing (), checkpoint_cb , progress_cb ]
379+ callbacks = [train . MetricsGraphing (), checkpoint_cb , progress_cb ]
505380
506381 swa_callback = None
507382 if args .swa :
508- swa_callback = SwaCallback (start_epoch = args .epochs * 2 // 3 )
383+ swa_callback = train . SwaCallback (start_epoch = args .epochs * 2 // 3 )
509384 callbacks .append (swa_callback )
510385
511386 # TODO: inf norm?
0 commit comments