@@ -809,17 +809,35 @@ def main():
     criterion = DiceLoss(smooth=1.0)
     optimizer = optim.Adam(model.parameters(), lr=args.learningRate)
 
+    checkpoint_dir = "checkpoints"
+    os.makedirs(checkpoint_dir, exist_ok=True)
+    latest_checkpoint_path = os.path.join(checkpoint_dir, "xpoint_model_latest.pt")
+    start_epoch = 0
+    train_loss = []
+    val_loss = []
+
+    if os.path.exists(latest_checkpoint_path):
+        model, optimizer, start_epoch, train_loss, val_loss = load_model_checkpoint(
+            model, optimizer, latest_checkpoint_path
+        )
+        print(f"Resuming training from epoch {start_epoch + 1}")
+    else:
+        print("Starting training from scratch")
+
     t2 = timer()
     print("time (s) to prepare model: " + str(t2 - t1))
 
     train_loss = []
     val_loss = []
-
+
     num_epochs = args.epochs
-    for epoch in range(num_epochs):
+    for epoch in range(start_epoch, num_epochs):
         train_loss.append(train_one_epoch(model, train_loader, criterion, optimizer, device))
         val_loss.append(validate_one_epoch(model, val_loader, criterion, device))
         print(f"[Epoch {epoch + 1}/{num_epochs}] TrainLoss={train_loss[-1]} ValLoss={val_loss[-1]}")
+
+        # Save model checkpoint after each epoch
+        save_model_checkpoint(model, optimizer, train_loss, val_loss, epoch + 1, checkpoint_dir)
 
     plot_training_history(train_loss, val_loss)
     print("time (s) to train model: " + str(timer() - t2))
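
The diff calls `load_model_checkpoint` and `save_model_checkpoint`, but their definitions are not part of this hunk. A minimal sketch of what such helpers could look like, assuming they wrap `torch.save` / `torch.load` around plain state dicts; the dict key names and the `map_location` choice below are assumptions, not taken from the repository:

```python
import os

import torch


def save_model_checkpoint(model, optimizer, train_loss, val_loss, epoch, checkpoint_dir):
    # Persist everything needed to resume: weights, optimizer state,
    # loss history, and the number of completed epochs.
    state = {
        "epoch": epoch,                                   # assumed key names
        "model_state_dict": model.state_dict(),
        "optimizer_state_dict": optimizer.state_dict(),
        "train_loss": train_loss,
        "val_loss": val_loss,
    }
    # Overwrite a single "latest" file so the resume check above always
    # finds the most recent epoch.
    torch.save(state, os.path.join(checkpoint_dir, "xpoint_model_latest.pt"))


def load_model_checkpoint(model, optimizer, checkpoint_path):
    # Restore weights and optimizer state in place and return the saved
    # training history so the loss curves continue across restarts.
    checkpoint = torch.load(checkpoint_path, map_location="cpu")
    model.load_state_dict(checkpoint["model_state_dict"])
    optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
    return (
        model,
        optimizer,
        checkpoint["epoch"],
        checkpoint["train_loss"],
        checkpoint["val_loss"],
    )
```

Saving the optimizer state alongside the weights is what lets a resumed run keep Adam's moment estimates instead of restarting them from zero.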