diff --git a/bayesflow/workflows/basic_workflow.py b/bayesflow/workflows/basic_workflow.py
index 93c3c2788..16bf362db 100644
--- a/bayesflow/workflows/basic_workflow.py
+++ b/bayesflow/workflows/basic_workflow.py
@@ -32,7 +32,7 @@ def __init__(
         checkpoint_filepath: str = None,
         checkpoint_name: str = "model",
         save_weights_only: bool = False,
-        save_best_only: bool = True,
+        save_best_only: bool = False,
         inference_variables: Sequence[str] | str = "theta",
         inference_conditions: Sequence[str] | str = "x",
         summary_variables: Sequence[str] | str = None,
@@ -67,8 +67,10 @@ def __init__(
         save_weights_only : bool, optional
            If True, only the model weights will be saved during checkpointing (default is False).
         save_best_only: bool, optional
-            If only the latest best model according to the quantity monitored (loss or validation) at the end of
-            each epoch will be saved. Consider setting to False when using FlowMatching (default is True)
+            If only the best model according to the quantity monitored (loss or validation) at the end of
+            each epoch will be saved instead of the last model (default is False). Use with caution,
+            as some losses (e.g. flow matching) do not reliably reflect model performance, and outliers in the
+            validation data can cause unwanted effects.
         inference_variables : Sequence[str] or str, optional
             Variables for inference as a sequence of strings or a single string (default is "theta").
             Important for automating diagnostics!
@@ -114,6 +116,25 @@ def __init__(
         self.checkpoint_name = checkpoint_name
         self.save_weights_only = save_weights_only
         self.save_best_only = save_best_only
+        if self.checkpoint_filepath is not None:
+            if self.save_weights_only:
+                file_ext = self.checkpoint_name + ".weights.h5"
+            else:
+                file_ext = self.checkpoint_name + ".keras"
+            checkpoint_full_filepath = os.path.join(self.checkpoint_filepath, file_ext)
+            if os.path.exists(checkpoint_full_filepath):
+                msg = (
+                    f"Checkpoint file exists: '{checkpoint_full_filepath}'.\n"
+                    "Existing checkpoints can _not_ be restored/loaded using this workflow. "
+                    "Upon refitting, the checkpoints will be overwritten."
+                )
+                if not self.save_weights_only:
+                    msg += (
+                        " To load the stored approximator from the checkpoint, "
+                        "use approximator = keras.saving.load_model(...)"
+                    )
+
+                logging.warning(msg)
         self.history = None

     @staticmethod
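
Note (not part of the patch): the warning added above points users to keras.saving.load_model for full-model checkpoints. Below is a minimal sketch of how a checkpoint written by this workflow could be reloaded, assuming hypothetical values checkpoint_filepath="checkpoints" and checkpoint_name="model":

    import os
    import keras

    # Hypothetical values; use whatever was passed to the workflow.
    checkpoint_filepath = "checkpoints"
    checkpoint_name = "model"

    # save_weights_only=False: the workflow writes "<checkpoint_name>.keras"; restore the
    # full approximator with keras.saving.load_model, as the warning message suggests.
    approximator = keras.saving.load_model(
        os.path.join(checkpoint_filepath, checkpoint_name + ".keras")
    )

    # save_weights_only=True: the workflow writes "<checkpoint_name>.weights.h5"; rebuild
    # an approximator with the same architecture first, then load the weights into it:
    # approximator.load_weights(
    #     os.path.join(checkpoint_filepath, checkpoint_name + ".weights.h5")
    # )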