Skip to content

Commit baa8fb5

Browse files
authored
fix overwriting best.pt every validation
1 parent a123419 commit baa8fb5

File tree

1 file changed

+1
-0
lines changed

1 file changed

+1
-0
lines changed

the_well/benchmark/trainer/training.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -460,6 +460,7 @@ def train(self):
460460
self.save_model(
461461
epoch, val_loss, os.path.join(self.checkpoint_folder, "best.pt")
462462
)
463+
self.best_val_loss = val_loss
463464
# Check if time for expensive validation - periodic or final
464465
if epoch % self.rollout_val_frequency == 0 or (epoch == self.max_epoch):
465466
logger.info(

0 commit comments

Comments
 (0)