Skip to content

Commit 7bd6146

Browse files
committed
basic workflow: change default for save best only
1 parent 67057f0 commit 7bd6146

File tree

1 file changed

+24
-3
lines changed

1 file changed

+24
-3
lines changed

bayesflow/workflows/basic_workflow.py

Lines changed: 24 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ def __init__(
3232
checkpoint_filepath: str = None,
3333
checkpoint_name: str = "model",
3434
save_weights_only: bool = False,
35-
save_best_only: bool = True,
35+
save_best_only: bool = False,
3636
inference_variables: Sequence[str] | str = "theta",
3737
inference_conditions: Sequence[str] | str = "x",
3838
summary_variables: Sequence[str] | str = None,
@@ -67,8 +67,10 @@ def __init__(
6767
save_weights_only : bool, optional
6868
If True, only the model weights will be saved during checkpointing (default is False).
6969
save_best_only: bool, optional
70-
If only the latest best model according to the quantity monitored (loss or validation) at the end of
71-
each epoch will be saved. Consider setting to False when using FlowMatching (default is True)
70+
If only the best model according to the quantity monitored (loss or validation) at the end of
71+
each epoch will be saved instead of the last model (default is False). Use with caution,
72+
as some losses (e.g. flow matching) do not reliably reflect model performance, and outliers in the
73+
validation data can cause unwanted effects.
7274
inference_variables : Sequence[str] or str, optional
7375
Variables for inference as a sequence of strings or a single string (default is "theta").
7476
Important for automating diagnostics!
@@ -114,6 +116,25 @@ def __init__(
114116
self.checkpoint_name = checkpoint_name
115117
self.save_weights_only = save_weights_only
116118
self.save_best_only = save_best_only
119+
if self.checkpoint_filepath is not None:
120+
if self.save_weights_only:
121+
file_ext = self.checkpoint_name + ".weights.h5"
122+
else:
123+
file_ext = self.checkpoint_name + ".keras"
124+
checkpoint_full_filepath = os.path.join(self.checkpoint_filepath, file_ext)
125+
if os.path.exists(checkpoint_filepath):
126+
msg = (
127+
f"Checkpoint file exists: '{checkpoint_full_filepath}'.\n"
128+
"Existing checkpoints can _not_ be restored/loaded using this workflow. "
129+
"Upon refitting, the checkpoints will be overwritten."
130+
)
131+
if not self.save_weights_only:
132+
msg += (
133+
" To load the stored approximator from the checkpoint, "
134+
"use approximator = keras.saving.load_model(...)"
135+
)
136+
137+
logging.warning(msg)
117138
self.history = None
118139

119140
@staticmethod

0 commit comments

Comments
 (0)