Commit 3b196c6

Merge pull request #139 from elseml/Development
Fix offline training for model comparison ignoring shared context
2 parents: 512046c + 56a8dd5

2 files changed: 14 additions & 7 deletions
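For context on the fix: in model comparison, the forward dict produced by the simulators can carry entries beyond the per-model outputs and the model indices, e.g. context that is shared across all candidate models. Before this commit, the offline dataset wrapper kept only the model outputs and model indices, so any shared context never reached the networks during offline training. Below is an illustrative sketch of such a forward dict; the "shared_context" key and all shapes are made up, and the two reserved key names are assumed to match the string values behind DEFAULT_KEYS["model_outputs"] and DEFAULT_KEYS["model_indices"].

# Illustrative forward dict for offline model comparison (hypothetical shapes
# and key names; not taken from the library).
import numpy as np

forward_dict = {
    "model_outputs": [                        # one simulated batch per candidate model
        {"sim_data": np.random.rand(64, 10, 2)},
        {"sim_data": np.random.rand(64, 10, 2)},
    ],
    "model_indices": np.array([0, 1]),        # index of each candidate model (illustrative)
    "shared_context": np.full((64, 1), 10),   # e.g. number of observations, shared by all models
}

# Prior to this commit, offline batches contained only "model_outputs" and
# "model_indices"; "shared_context" was silently dropped.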

bayesflow/helper_classes.py

Lines changed: 8 additions & 0 deletions
@@ -141,10 +141,18 @@ def __init__(self, forward_dict, batch_size, buffer_size=1024):
         self.iters = [iter(d) for d in self.datasets]
         self.batch_size = batch_size

+        # Include further keys (= shared context) from forward_dict
+        self.further_keys = {}
+        for key, value in forward_dict.items():
+            if key not in [DEFAULT_KEYS["model_outputs"], DEFAULT_KEYS["model_indices"]]:
+                self.further_keys[key] = value
+
     def __next__(self):
         if self.current_it < self.num_batches:
             outputs = [next(d) for d in self.iters]
             output_dict = {DEFAULT_KEYS["model_outputs"]: outputs, DEFAULT_KEYS["model_indices"]: self.model_indices}
+            if self.further_keys:
+                output_dict.update(self.further_keys)
             self.current_it += 1
             return output_dict
         self.current_it = 0
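In isolation, the addition above does two things: at construction time, every key of forward_dict other than the model outputs and model indices is stashed in self.further_keys, and in __next__ those stored values are merged unchanged into every batch dict the dataset yields, which is what makes them "shared" context rather than per-batch data. A standalone sketch of that logic, with plain string keys standing in for the DEFAULT_KEYS values:

# Standalone sketch of the passthrough added above (not the library class itself);
# "model_outputs"/"model_indices" stand in for the DEFAULT_KEYS values.
RESERVED_KEYS = ("model_outputs", "model_indices")

def collect_further_keys(forward_dict):
    # Mirrors the __init__ addition: keep everything except the reserved keys.
    return {k: v for k, v in forward_dict.items() if k not in RESERVED_KEYS}

def attach_shared_context(batch_dict, further_keys):
    # Mirrors the __next__ addition: merge the stored shared context into a batch.
    if further_keys:
        batch_dict.update(further_keys)
    return batch_dict

# Usage sketch:
further = collect_further_keys({"model_outputs": [], "model_indices": [], "n_obs": 10})
batch = attach_shared_context({"model_outputs": [], "model_indices": []}, further)
assert batch["n_obs"] == 10  # shared context now reaches every offline batch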

bayesflow/trainers.py

Lines changed: 6 additions & 7 deletions
@@ -454,8 +454,8 @@ def train_online(
                     p_bar.update(1)

             # Store and compute validation loss, if specified
-            self._save_trainer(save_checkpoint)
             self._validation(ep, validation_sims, **kwargs)
+            self._save_trainer(save_checkpoint)

             # Check early stopping, if specified
             if self._check_early_stopping(early_stopper):
@@ -579,13 +579,13 @@ def train_offline(
                     # Format for display on progress bar
                     disp_str = format_loss_string(ep, bi, loss, avg_dict, lr=lr, it_str="Batch")

-                    # Update progress
+                    # Update progress bar
                     p_bar.set_postfix_str(disp_str, refresh=False)
                     p_bar.update(1)

             # Store and compute validation loss, if specified
-            self._save_trainer(save_checkpoint)
             self._validation(ep, validation_sims, **kwargs)
+            self._save_trainer(save_checkpoint)

             # Check early stopping, if specified
             if self._check_early_stopping(early_stopper):
@@ -762,15 +762,14 @@ def train_from_presimulation(
                     p_bar.update(1)

             # Store after each epoch, if specified
-            self._save_trainer(save_checkpoint)
-
             self._validation(ep, validation_sims, **kwargs)
+            self._save_trainer(save_checkpoint)

             # Check early stopping, if specified
             if self._check_early_stopping(early_stopper):
                 break

-        # Remove reference to optimizer, if not set to persistent
+        # Remove optimizer reference, if not set as persistent
         if not reuse_optimizer:
             self.optimizer = None
         return self.loss_history.get_plottable()
@@ -906,8 +905,8 @@ def train_experience_replay(
                     p_bar.update(1)

             # Store and compute validation loss, if specified
-            self._save_trainer(save_checkpoint)
             self._validation(ep, validation_sims, **kwargs)
+            self._save_trainer(save_checkpoint)

             # Check early stopping, if specified
             if self._check_early_stopping(early_stopper):
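The trainers.py hunks are a consistent reordering plus two comment touch-ups: in each training loop, _validation now runs before _save_trainer, presumably so that the stored trainer state already reflects the epoch's validation loss. A minimal stub of the resulting end-of-epoch sequence; everything other than the three method names taken from the diff is hypothetical.

# Minimal stub showing the reordered end-of-epoch sequence; all names other
# than _validation, _save_trainer and _check_early_stopping are hypothetical.
class TrainerSketch:
    def _validation(self, ep, validation_sims, **kwargs):
        print(f"epoch {ep}: computed validation loss on {validation_sims} sims")

    def _save_trainer(self, save_checkpoint):
        if save_checkpoint:
            print("stored trainer state (now includes this epoch's validation loss)")

    def _check_early_stopping(self, early_stopper):
        return False  # never triggers in this sketch

    def train(self, epochs=2, validation_sims=100, save_checkpoint=True, early_stopper=None):
        for ep in range(1, epochs + 1):
            # ... batch loop with progress-bar updates would run here ...

            # New order: validate first, then store
            self._validation(ep, validation_sims)
            self._save_trainer(save_checkpoint)

            # Check early stopping, if specified
            if self._check_early_stopping(early_stopper):
                break

TrainerSketch().train()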
