@@ -709,6 +709,78 @@ def printCommandLineArgs(args):
         print(f"  {arg}: {getattr(args, arg)}")
     print("}")
 
+# Function to save model checkpoint
+def save_model_checkpoint(model, optimizer, train_loss, val_loss, epoch, checkpoint_dir="checkpoints"):
+    """
+    Save model checkpoint including model state, optimizer state, and training metrics
+
+    Parameters:
+        model: The neural network model
+        optimizer: The optimizer used for training
+        train_loss: List of training losses
+        val_loss: List of validation losses
+        epoch: Current epoch number
+        checkpoint_dir: Directory to save checkpoints
+    """
+    os.makedirs(checkpoint_dir, exist_ok=True)
+
+    checkpoint_path = os.path.join(checkpoint_dir, f"xpoint_model_epoch_{epoch}.pt")
+
+    # Create checkpoint dictionary
+    checkpoint = {
+        'epoch': epoch,
+        'model_state_dict': model.state_dict(),
+        'optimizer_state_dict': optimizer.state_dict(),
+        'train_loss': train_loss,
+        'val_loss': val_loss
+    }
+
+    # Save checkpoint
+    torch.save(checkpoint, checkpoint_path)
+    print(f"Model checkpoint saved at epoch {epoch} to {checkpoint_path}")
+
+    # Save the latest model separately for easy loading
+    latest_path = os.path.join(checkpoint_dir, "xpoint_model_latest.pt")
+    torch.save(checkpoint, latest_path)
+    print(f"Latest model saved to {latest_path}")
+
+
+
+# Function to load model checkpoint
+def load_model_checkpoint(model, optimizer, checkpoint_path):
+    """
+    Load model checkpoint
+
+    Parameters:
+        model: The neural network model to load weights into
+        optimizer: The optimizer to load state into
+        checkpoint_path: Path to the checkpoint file
+
+    Returns:
+        model: Updated model with loaded weights
+        optimizer: Updated optimizer with loaded state
+        epoch: Last saved epoch number
+        train_loss: List of training losses
+        val_loss: List of validation losses
+    """
+    if not os.path.exists(checkpoint_path):
+        print(f"No checkpoint found at {checkpoint_path}")
+        return model, optimizer, 0, [], []
+
+    print(f"Loading checkpoint from {checkpoint_path}")
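+    # NOTE: torch.load also accepts map_location= (e.g. "cpu") in case the
+    # checkpoint was saved on a different device than the one it is restored on.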
+    checkpoint = torch.load(checkpoint_path)
+
+    model.load_state_dict(checkpoint['model_state_dict'])
+    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
+
+    epoch = checkpoint['epoch']
+    train_loss = checkpoint['train_loss']
+    val_loss = checkpoint['val_loss']
+
+    print(f"Loaded checkpoint from epoch {epoch}")
+    return model, optimizer, epoch, train_loss, val_loss
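+
+# Example (hypothetical) standalone usage once a checkpoint exists:
+#   model, optimizer, start_epoch, train_loss, val_loss = load_model_checkpoint(
+#       model, optimizer, "checkpoints/xpoint_model_latest.pt")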
+
+
 def main():
     args = parseCommandLineArgs()
     checkCommandLineArgs(args)
@@ -741,17 +813,35 @@ def main():
     criterion = DiceLoss(smooth=1.0)
     optimizer = optim.Adam(model.parameters(), lr=args.learningRate)
 
+    checkpoint_dir = "checkpoints"
+    os.makedirs(checkpoint_dir, exist_ok=True)
+    latest_checkpoint_path = os.path.join(checkpoint_dir, "xpoint_model_latest.pt")
+    start_epoch = 0
+    train_loss = []
+    val_loss = []
+
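+    # Restoring the optimizer state as well preserves Adam's running moment
+    # estimates, so resumed training picks up where it left off instead of
+    # re-warming the optimizer from scratch.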
+    if os.path.exists(latest_checkpoint_path):
+        model, optimizer, start_epoch, train_loss, val_loss = load_model_checkpoint(
+            model, optimizer, latest_checkpoint_path
+        )
+        print(f"Resuming training from epoch {start_epoch + 1}")
+    else:
+        print("Starting training from scratch")
+
     t2 = timer()
     print("time (s) to prepare model: " + str(t2 - t1))
 
-    train_loss = []
-    val_loss = []
 
     num_epochs = args.epochs
-    for epoch in range(num_epochs):
+    for epoch in range(start_epoch, num_epochs):
         train_loss.append(train_one_epoch(model, train_loader, criterion, optimizer, device))
         val_loss.append(validate_one_epoch(model, val_loader, criterion, device))
         print(f"[Epoch {epoch + 1}/{num_epochs}] TrainLoss={train_loss[-1]} ValLoss={val_loss[-1]}")
+
+        # Save model checkpoint after each epoch
+        save_model_checkpoint(model, optimizer, train_loss, val_loss, epoch + 1, checkpoint_dir)
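+        # NOTE: a new xpoint_model_epoch_<n>.pt file is written every epoch (and
+        # xpoint_model_latest.pt is overwritten), so old checkpoints accumulate on
+        # disk unless they are pruned manually.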
 
     plot_training_history(train_loss, val_loss)
     print("time (s) to train model: " + str(timer() - t2))