Skip to content

Commit 9d8d10b

Browse files
committed
control min training loss from command line
1 parent b628860 commit 9d8d10b

File tree

1 file changed

+7
-1
lines changed

1 file changed

+7
-1
lines changed

XPointMLTest.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -590,6 +590,8 @@ def parseCommandLineArgs():
590590
help='specify the number of the first frame used for validation')
591591
parser.add_argument('--validationFrameLast', type=int, default=150,
592592
help='specify the number of the last frame (exclusive) used for validation')
593+
parser.add_argument('--minTrainingLoss', type=int, default=3,
594+
help='minimum reduction in training loss in orders of magnitude')
593595
parser.add_argument('--paramFile', type=Path, default=None,
594596
help='''
595597
specify the path to the parameter txt file, the parent
@@ -632,6 +634,10 @@ def checkCommandLineArgs(args):
632634
print(f"validation frame range isn't valid... exiting")
633635
sys.exit()
634636

637+
if args.minTrainingLoss < 0:
638+
print(f"minTrainingLoss must be >= 0... exiting")
639+
sys.exit()
640+
635641
def main():
636642
args = parseCommandLineArgs()
637643
checkCommandLineArgs(args)
@@ -671,7 +677,7 @@ def main():
671677

672678
print("time (s) to train model: " + str(timer()-t2))
673679

674-
requiredLossDecreaseMagnitude = 3;
680+
requiredLossDecreaseMagnitude = args.minTrainingLoss
675681
if np.log10(abs(train_loss[0]/train_loss[-1])) < requiredLossDecreaseMagnitude:
676682
print(f"TrainLoss reduced by less than {requiredLossDecreaseMagnitude} orders of magnitude: "
677683
f"initial {train_loss[0]} final {train_loss[-1]} ... exiting")

0 commit comments

Comments
 (0)