@@ -704,6 +704,78 @@ def printCommandLineArgs(args):
         print(f" {arg} : {getattr(args, arg)}")
     print("}")

+# Function to save model checkpoint
+def save_model_checkpoint(model, optimizer, train_loss, val_loss, epoch, checkpoint_dir="checkpoints"):
+    """
+    Save model checkpoint including model state, optimizer state, and training metrics
+
+    Parameters:
+        model: The neural network model
+        optimizer: The optimizer used for training
+        train_loss: List of training losses
+        val_loss: List of validation losses
+        epoch: Current epoch number
+        checkpoint_dir: Directory to save checkpoints
+    """
+    os.makedirs(checkpoint_dir, exist_ok=True)
+
+    checkpoint_path = os.path.join(checkpoint_dir, f"xpoint_model_epoch_{epoch}.pt")
+
+    # Create checkpoint dictionary
+    checkpoint = {
+        'epoch': epoch,
+        'model_state_dict': model.state_dict(),
+        'optimizer_state_dict': optimizer.state_dict(),
+        'train_loss': train_loss,
+        'val_loss': val_loss
+    }
+
+    # Save checkpoint
+    torch.save(checkpoint, checkpoint_path)
+    print(f"Model checkpoint saved at epoch {epoch} to {checkpoint_path}")
+
+    # Save the latest model separately for easy loading
+    latest_path = os.path.join(checkpoint_dir, "xpoint_model_latest.pt")
+    torch.save(checkpoint, latest_path)
+    print(f"Latest model saved to {latest_path}")
+
+
+
+# Function to load model checkpoint
+def load_model_checkpoint(model, optimizer, checkpoint_path):
+    """
+    Load model checkpoint
+
+    Parameters:
+        model: The neural network model to load weights into
+        optimizer: The optimizer to load state into
+        checkpoint_path: Path to the checkpoint file
+
+    Returns:
+        model: Updated model with loaded weights
+        optimizer: Updated optimizer with loaded state
+        epoch: Last saved epoch number
+        train_loss: List of training losses
+        val_loss: List of validation losses
+    """
+    if not os.path.exists(checkpoint_path):
+        print(f"No checkpoint found at {checkpoint_path}")
+        return model, optimizer, 0, [], []
+
+    print(f"Loading checkpoint from {checkpoint_path}")
+    checkpoint = torch.load(checkpoint_path)
+
+    model.load_state_dict(checkpoint['model_state_dict'])
+    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
+
+    epoch = checkpoint['epoch']
+    train_loss = checkpoint['train_loss']
+    val_loss = checkpoint['val_loss']
+
+    print(f"Loaded checkpoint from epoch {epoch}")
+    return model, optimizer, epoch, train_loss, val_loss
+
+
 def main():
     args = parseCommandLineArgs()
     checkCommandLineArgs(args)
@@ -737,17 +809,33 @@ def main():
     criterion = DiceLoss(smooth=1.0)
     optimizer = optim.Adam(model.parameters(), lr=args.learningRate)

+    checkpoint_dir = "checkpoints"
+    os.makedirs(checkpoint_dir, exist_ok=True)
+    latest_checkpoint_path = os.path.join(checkpoint_dir, "xpoint_model_latest.pt")
+    start_epoch = 0
+    train_loss = []
+    val_loss = []
+
+    if os.path.exists(latest_checkpoint_path):
+        model, optimizer, start_epoch, train_loss, val_loss = load_model_checkpoint(
+            model, optimizer, latest_checkpoint_path
+        )
+        print(f"Resuming training from epoch {start_epoch + 1}")
+    else:
+        print("Starting training from scratch")
+
     t2 = timer()
     print("time (s) to prepare model: " + str(t2 - t1))

-    train_loss = []
-    val_loss = []

     num_epochs = args.epochs
-    for epoch in range(num_epochs):
+    for epoch in range(start_epoch, num_epochs):
         train_loss.append(train_one_epoch(model, train_loader, criterion, optimizer, device))
         val_loss.append(validate_one_epoch(model, val_loader, criterion, device))
         print(f"[Epoch {epoch + 1}/{num_epochs}] TrainLoss={train_loss[-1]} ValLoss={val_loss[-1]}")
+
+        # Save model checkpoint after each epoch
+        save_model_checkpoint(model, optimizer, train_loss, val_loss, epoch + 1, checkpoint_dir)

     plot_training_history(train_loss, val_loss)
     print("time (s) to train model: " + str(timer() - t2))
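
Note: load_model_checkpoint above calls torch.load(checkpoint_path) with no map_location, so tensors are restored to the device they were saved from; a checkpoint written on a GPU machine will fail to load on a CPU-only one. Below is a minimal sketch of resuming onto an explicit device. It assumes a hypothetical build_model() constructor and a fixed learning rate standing in for the script's own model setup and args.learningRate; the checkpoint keys match the functions above.

import os
import torch
import torch.optim as optim

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = build_model().to(device)               # hypothetical constructor; the real script builds its model elsewhere
optimizer = optim.Adam(model.parameters(), lr=1e-3)  # placeholder for args.learningRate

checkpoint_path = os.path.join("checkpoints", "xpoint_model_latest.pt")
if os.path.exists(checkpoint_path):
    # map_location remaps GPU-saved tensors onto the chosen device,
    # so the checkpoint can be read even on a CPU-only machine
    checkpoint = torch.load(checkpoint_path, map_location=device)
    model.load_state_dict(checkpoint["model_state_dict"])
    optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
    start_epoch = checkpoint["epoch"]
    train_loss = checkpoint["train_loss"]
    val_loss = checkpoint["val_loss"]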