Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 24 additions & 3 deletions bayesflow/workflows/basic_workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def __init__(
checkpoint_filepath: str = None,
checkpoint_name: str = "model",
save_weights_only: bool = False,
save_best_only: bool = True,
save_best_only: bool = False,
inference_variables: Sequence[str] | str = "theta",
inference_conditions: Sequence[str] | str = "x",
summary_variables: Sequence[str] | str = None,
Expand Down Expand Up @@ -67,8 +67,10 @@ def __init__(
save_weights_only : bool, optional
If True, only the model weights will be saved during checkpointing (default is False).
save_best_only: bool, optional
If only the latest best model according to the quantity monitored (loss or validation) at the end of
each epoch will be saved. Consider setting to False when using FlowMatching (default is True)
If only the best model according to the quantity monitored (loss or validation) at the end of
each epoch will be saved instead of the last model (default is False). Use with caution,
as some losses (e.g. flow matching) do not reliably reflect model performance, and outliers in the
validation data can cause unwanted effects.
inference_variables : Sequence[str] or str, optional
Variables for inference as a sequence of strings or a single string (default is "theta").
Important for automating diagnostics!
Expand Down Expand Up @@ -114,6 +116,25 @@ def __init__(
self.checkpoint_name = checkpoint_name
self.save_weights_only = save_weights_only
self.save_best_only = save_best_only
if self.checkpoint_filepath is not None:
if self.save_weights_only:
file_ext = self.checkpoint_name + ".weights.h5"
else:
file_ext = self.checkpoint_name + ".keras"
checkpoint_full_filepath = os.path.join(self.checkpoint_filepath, file_ext)
if os.path.exists(checkpoint_full_filepath):
msg = (
f"Checkpoint file exists: '{checkpoint_full_filepath}'.\n"
"Existing checkpoints can _not_ be restored/loaded using this workflow. "
"Upon refitting, the checkpoints will be overwritten."
)
if not self.save_weights_only:
msg += (
" To load the stored approximator from the checkpoint, "
"use approximator = keras.saving.load_model(...)"
)

logging.warning(msg)
self.history = None

@staticmethod
Expand Down