Skip to content

Commit 6e0737b

Browse files
committed
cmd line option to control checkpoint frequency
1 parent bf75326 commit 6e0737b

File tree

1 file changed

+8
-1
lines changed

1 file changed

+8
-1
lines changed

XPointMLTest.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -645,6 +645,8 @@ def parseCommandLineArgs():
645645
help='specify the number of the last frame (exclusive) used for validation')
646646
parser.add_argument('--minTrainingLoss', type=int, default=3,
647647
help='minimum reduction in training loss in orders of magnitude')
648+
parser.add_argument('--checkPointFrequency', type=int, default=10,
649+
help='number of epochs between checkpoints')
648650
parser.add_argument('--paramFile', type=Path, default=None,
649651
help='''
650652
specify the path to the parameter txt file, the parent
@@ -703,6 +705,10 @@ def checkCommandLineArgs(args):
703705
print(f"minTrainingLoss must be >= 0... exiting")
704706
sys.exit()
705707

708+
if args.checkPointFrequency < 0:
709+
print(f"checkPointFrequency must be >= 0... exiting")
710+
sys.exit()
711+
706712
def printCommandLineArgs(args):
707713
print("Config {")
708714
for arg in vars(args):
@@ -846,7 +852,8 @@ def main():
846852
print(f"[Epoch {epoch+1}/{num_epochs}] TrainLoss={train_loss[-1]} ValLoss={val_loss[-1]}")
847853

848854
# Save a model checkpoint every args.checkPointFrequency epochs
849-
save_model_checkpoint(model, optimizer, train_loss, val_loss, epoch+1, checkpoint_dir)
855+
if epoch % args.checkPointFrequency == 0:
856+
save_model_checkpoint(model, optimizer, train_loss, val_loss, epoch+1, checkpoint_dir)
850857

851858
plot_training_history(train_loss, val_loss)
852859
print("time (s) to train model: " + str(timer()-t2))

0 commit comments

Comments
 (0)