Skip to content

Commit 12bade3

Browse files
committed
Add autograph to train_from_presim [skip ci]
1 parent 8edc54a commit 12bade3

File tree

1 file changed

+8
-2
lines changed

1 file changed

+8
-2
lines changed

bayesflow/trainers.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -528,7 +528,13 @@ def train_from_presimulation(self, presimulation_path, optimizer, save_checkpoin
528528
-------
529529
losses : dict or pandas.DataFrame
530530
A dictionary or pandas.DataFrame storing the losses across epochs and iterations
531-
"""
531+
"""
532+
533+
# Compile update function, if specified
534+
if use_autograph:
535+
_backprop_step = tf.function(backprop_step, reduce_retracing=True)
536+
else:
537+
_backprop_step = _backprop_step
532538

533539
# Use default loading function if none is provided
534540
if custom_loader is None:
@@ -567,7 +573,7 @@ def train_from_presimulation(self, presimulation_path, optimizer, save_checkpoin
567573

568574
# Like the number of iterations, the batch size is inferred from presimulated dictionary or list
569575
batch_size = epoch_data[index][DEFAULT_KEYS['sim_data']].shape[0]
570-
loss = self._train_step(batch_size, input_dict, **kwargs)
576+
loss = self._train_step(batch_size, _backprop_step, input_dict, **kwargs)
571577

572578
# Store returned loss
573579
self.loss_history.add_entry(ep, loss)

0 commit comments

Comments
 (0)