Skip to content

Commit 261bc0e

Browse files
committed
load and save model in the main function
1 parent 4aaa16c commit 261bc0e

File tree

1 file changed

+20
-2
lines changed

1 file changed

+20
-2
lines changed

XPointMLTest.py

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -809,17 +809,35 @@ def main():
809809
criterion = DiceLoss(smooth=1.0)
810810
optimizer = optim.Adam(model.parameters(), lr=args.learningRate)
811811

812+
checkpoint_dir = "checkpoints"
813+
os.makedirs(checkpoint_dir, exist_ok=True)
814+
latest_checkpoint_path = os.path.join(checkpoint_dir, "xpoint_model_latest.pt")
815+
start_epoch = 0
816+
train_loss = []
817+
val_loss = []
818+
819+
if os.path.exists(latest_checkpoint_path):
820+
model, optimizer, start_epoch, train_loss, val_loss = load_model_checkpoint(
821+
model, optimizer, latest_checkpoint_path
822+
)
823+
print(f"Resuming training from epoch {start_epoch+1}")
824+
else:
825+
print("Starting training from scratch")
826+
812827
t2 = timer()
813828
print("time (s) to prepare model: " + str(t2-t1))
814829

815830
train_loss = []
816831
val_loss = []
817-
832+
818833
num_epochs = args.epochs
819-
for epoch in range(num_epochs):
834+
for epoch in range(start_epoch, num_epochs):
820835
train_loss.append(train_one_epoch(model, train_loader, criterion, optimizer, device))
821836
val_loss.append(validate_one_epoch(model, val_loader, criterion, device))
822837
print(f"[Epoch {epoch+1}/{num_epochs}] TrainLoss={train_loss[-1]} ValLoss={val_loss[-1]}")
838+
839+
# Save model checkpoint after each epoch
840+
save_model_checkpoint(model, optimizer, train_loss, val_loss, epoch+1, checkpoint_dir)
823841

824842
plot_training_history(train_loss, val_loss)
825843
print("time (s) to train model: " + str(timer()-t2))

0 commit comments

Comments
 (0)