@@ -645,6 +645,8 @@ def parseCommandLineArgs():
645645 help = 'specify the number of the last frame (exclusive) used for validation' )
646646 parser .add_argument ('--minTrainingLoss' , type = int , default = 3 ,
647647 help = 'minimum reduction in training loss in orders of magnitude' )
648+ parser .add_argument ('--checkPointFrequency' , type = int , default = 10 ,
649+ help = 'number of epochs between checkpoints' )
648650 parser .add_argument ('--paramFile' , type = Path , default = None ,
649651 help = '''
650652 specify the path to the parameter txt file, the parent
@@ -703,6 +705,10 @@ def checkCommandLineArgs(args):
703705 print (f"minTrainingLoss must be >= 0... exiting" )
704706 sys .exit ()
705707
708+ if args .checkPointFrequency < 0 :
709+ print (f"checkPointFrequency must be >= 0... exiting" )
710+ sys .exit ()
711+
706712def printCommandLineArgs (args ):
707713 print ("Config {" )
708714 for arg in vars (args ):
@@ -846,7 +852,8 @@ def main():
846852 print (f"[Epoch { epoch + 1 } /{ num_epochs } ] TrainLoss={ train_loss [- 1 ]} ValLoss={ val_loss [- 1 ]} " )
847853
848854 # Save model checkpoint after each epoch
849- save_model_checkpoint (model , optimizer , train_loss , val_loss , epoch + 1 , checkpoint_dir )
855+ if epoch % args .checkPointFrequency == 0 :
856+ save_model_checkpoint (model , optimizer , train_loss , val_loss , epoch + 1 , checkpoint_dir )
850857
851858 plot_training_history (train_loss , val_loss )
852859 print ("time (s) to train model: " + str (timer ()- t2 ))
0 commit comments