Skip to content

Commit 56a8dd5

Browse files
committed
Fix last epoch validation loss not saving
1 parent 9a5b125 commit 56a8dd5

File tree

1 file changed

+6
-7
lines changed

1 file changed

+6
-7
lines changed

bayesflow/trainers.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -454,8 +454,8 @@ def train_online(
454454
p_bar.update(1)
455455

456456
# Store and compute validation loss, if specified
457-
self._save_trainer(save_checkpoint)
458457
self._validation(ep, validation_sims, **kwargs)
458+
self._save_trainer(save_checkpoint)
459459

460460
# Check early stopping, if specified
461461
if self._check_early_stopping(early_stopper):
@@ -579,13 +579,13 @@ def train_offline(
579579
# Format for display on progress bar
580580
disp_str = format_loss_string(ep, bi, loss, avg_dict, lr=lr, it_str="Batch")
581581

582-
# Update progress
582+
# Update progress bar
583583
p_bar.set_postfix_str(disp_str, refresh=False)
584584
p_bar.update(1)
585585

586586
# Store and compute validation loss, if specified
587-
self._save_trainer(save_checkpoint)
588587
self._validation(ep, validation_sims, **kwargs)
588+
self._save_trainer(save_checkpoint)
589589

590590
# Check early stopping, if specified
591591
if self._check_early_stopping(early_stopper):
@@ -762,15 +762,14 @@ def train_from_presimulation(
762762
p_bar.update(1)
763763

764764
# Store after each epoch, if specified
765-
self._save_trainer(save_checkpoint)
766-
767765
self._validation(ep, validation_sims, **kwargs)
766+
self._save_trainer(save_checkpoint)
768767

769768
# Check early stopping, if specified
770769
if self._check_early_stopping(early_stopper):
771770
break
772771

773-
# Remove reference to optimizer, if not set to persistent
772+
# Remove optimizer reference, if not set as persistent
774773
if not reuse_optimizer:
775774
self.optimizer = None
776775
return self.loss_history.get_plottable()
@@ -906,8 +905,8 @@ def train_experience_replay(
906905
p_bar.update(1)
907906

908907
# Store and compute validation loss, if specified
909-
self._save_trainer(save_checkpoint)
910908
self._validation(ep, validation_sims, **kwargs)
909+
self._save_trainer(save_checkpoint)
911910

912911
# Check early stopping, if specified
913912
if self._check_early_stopping(early_stopper):

0 commit comments

Comments
 (0)