
Commit a6f3566

Merge pull request #10 from SCOREC/luy/loadAndSaveModel
luy/load and save model
2 parents 36ea19b + 261bc0e commit a6f3566

File tree

1 file changed: +92 -2 lines changed


XPointMLTest.py

Lines changed: 92 additions & 2 deletions
@@ -704,6 +704,78 @@ def printCommandLineArgs(args):
         print(f" {arg}: {getattr(args, arg)}")
     print("}")
 
+# Function to save model checkpoint
+def save_model_checkpoint(model, optimizer, train_loss, val_loss, epoch, checkpoint_dir="checkpoints"):
+    """
+    Save model checkpoint including model state, optimizer state, and training metrics
+
+    Parameters:
+        model: The neural network model
+        optimizer: The optimizer used for training
+        train_loss: List of training losses
+        val_loss: List of validation losses
+        epoch: Current epoch number
+        checkpoint_dir: Directory to save checkpoints
+    """
+    os.makedirs(checkpoint_dir, exist_ok=True)
+
+    checkpoint_path = os.path.join(checkpoint_dir, f"xpoint_model_epoch_{epoch}.pt")
+
+    # Create checkpoint dictionary
+    checkpoint = {
+        'epoch': epoch,
+        'model_state_dict': model.state_dict(),
+        'optimizer_state_dict': optimizer.state_dict(),
+        'train_loss': train_loss,
+        'val_loss': val_loss
+    }
+
+    # Save checkpoint
+    torch.save(checkpoint, checkpoint_path)
+    print(f"Model checkpoint saved at epoch {epoch} to {checkpoint_path}")
+
+    # Save the latest model separately for easy loading
+    latest_path = os.path.join(checkpoint_dir, "xpoint_model_latest.pt")
+    torch.save(checkpoint, latest_path)
+    print(f"Latest model saved to {latest_path}")
+
+
+
+# Function to load model checkpoint
+def load_model_checkpoint(model, optimizer, checkpoint_path):
+    """
+    Load model checkpoint
+
+    Parameters:
+        model: The neural network model to load weights into
+        optimizer: The optimizer to load state into
+        checkpoint_path: Path to the checkpoint file
+
+    Returns:
+        model: Updated model with loaded weights
+        optimizer: Updated optimizer with loaded state
+        epoch: Last saved epoch number
+        train_loss: List of training losses
+        val_loss: List of validation losses
+    """
+    if not os.path.exists(checkpoint_path):
+        print(f"No checkpoint found at {checkpoint_path}")
+        return model, optimizer, 0, [], []
+
+    print(f"Loading checkpoint from {checkpoint_path}")
+    checkpoint = torch.load(checkpoint_path)
+
+    model.load_state_dict(checkpoint['model_state_dict'])
+    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
+
+    epoch = checkpoint['epoch']
+    train_loss = checkpoint['train_loss']
+    val_loss = checkpoint['val_loss']
+
+    print(f"Loaded checkpoint from epoch {epoch}")
+    return model, optimizer, epoch, train_loss, val_loss
+
+
 def main():
     args = parseCommandLineArgs()
     checkCommandLineArgs(args)
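
The two helpers added above depend only on os, torch, and whatever model and optimizer objects the caller supplies. A minimal sketch of exercising them outside of main() is shown below; the tiny stand-in model, the learning rate, and the assumption that the helpers can be imported from XPointMLTest are illustrative, not part of this commit.

import torch.nn as nn
import torch.optim as optim

from XPointMLTest import save_model_checkpoint, load_model_checkpoint

# Stand-in model and optimizer (the real script builds these in main()).
model = nn.Sequential(nn.Conv2d(1, 8, 3, padding=1), nn.ReLU(), nn.Conv2d(8, 1, 3, padding=1))
optimizer = optim.Adam(model.parameters(), lr=1e-3)

# Write the per-epoch file and checkpoints/xpoint_model_latest.pt for "epoch 1".
save_model_checkpoint(model, optimizer, train_loss=[0.9], val_loss=[0.8], epoch=1)

# Reload; if the path were missing, the helper would return (model, optimizer, 0, [], []).
model, optimizer, epoch, train_loss, val_loss = load_model_checkpoint(
    model, optimizer, "checkpoints/xpoint_model_latest.pt"
)
print(f"Resumed at epoch {epoch} with {len(train_loss)} recorded training losses")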
@@ -737,17 +809,35 @@ def main():
     criterion = DiceLoss(smooth=1.0)
     optimizer = optim.Adam(model.parameters(), lr=args.learningRate)
 
+    checkpoint_dir = "checkpoints"
+    os.makedirs(checkpoint_dir, exist_ok=True)
+    latest_checkpoint_path = os.path.join(checkpoint_dir, "xpoint_model_latest.pt")
+    start_epoch = 0
+    train_loss = []
+    val_loss = []
+
+    if os.path.exists(latest_checkpoint_path):
+        model, optimizer, start_epoch, train_loss, val_loss = load_model_checkpoint(
+            model, optimizer, latest_checkpoint_path
+        )
+        print(f"Resuming training from epoch {start_epoch+1}")
+    else:
+        print("Starting training from scratch")
+
     t2 = timer()
     print("time (s) to prepare model: " + str(t2-t1))
 
     train_loss = []
     val_loss = []
-
+
     num_epochs = args.epochs
-    for epoch in range(num_epochs):
+    for epoch in range(start_epoch, num_epochs):
         train_loss.append(train_one_epoch(model, train_loader, criterion, optimizer, device))
         val_loss.append(validate_one_epoch(model, val_loader, criterion, device))
         print(f"[Epoch {epoch+1}/{num_epochs}] TrainLoss={train_loss[-1]} ValLoss={val_loss[-1]}")
+
+        # Save model checkpoint after each epoch
+        save_model_checkpoint(model, optimizer, train_loss, val_loss, epoch+1, checkpoint_dir)
 
     plot_training_history(train_loss, val_loss)
     print("time (s) to train model: " + str(timer()-t2))

0 commit comments
