@@ -32,7 +32,7 @@ def __init__(
3232 checkpoint_filepath : str = None ,
3333 checkpoint_name : str = "model" ,
3434 save_weights_only : bool = False ,
35- save_best_only : bool = True ,
35+ save_best_only : bool = False ,
3636 inference_variables : Sequence [str ] | str = "theta" ,
3737 inference_conditions : Sequence [str ] | str = "x" ,
3838 summary_variables : Sequence [str ] | str = None ,
@@ -67,8 +67,10 @@ def __init__(
6767 save_weights_only : bool, optional
6868 If True, only the model weights will be saved during checkpointing (default is False).
6969 save_best_only: bool, optional
70- If only the latest best model according to the quantity monitored (loss or validation) at the end of
71- each epoch will be saved. Consider setting to False when using FlowMatching (default is True)
70+ If only the best model according to the quantity monitored (loss or validation) at the end of
71+ each epoch will be saved instead of the last model (default is False). Use with caution,
72+ as some losses (e.g. flow matching) do not reliably reflect model performance, and outliers in the
73+ validation data can cause unwanted effects.
7274 inference_variables : Sequence[str] or str, optional
7375 Variables for inference as a sequence of strings or a single string (default is "theta").
7476 Important for automating diagnostics!
@@ -114,6 +116,25 @@ def __init__(
114116 self .checkpoint_name = checkpoint_name
115117 self .save_weights_only = save_weights_only
116118 self .save_best_only = save_best_only
119+ if self .checkpoint_filepath is not None :
120+ if self .save_weights_only :
121+ file_ext = self .checkpoint_name + ".weights.h5"
122+ else :
123+ file_ext = self .checkpoint_name + ".keras"
124+ checkpoint_full_filepath = os .path .join (self .checkpoint_filepath , file_ext )
125+ if os .path .exists (checkpoint_full_filepath ):
126+ msg = (
127+ f"Checkpoint file exists: '{ checkpoint_full_filepath } '.\n "
128+ "Existing checkpoints can _not_ be restored/loaded using this workflow. "
129+ "Upon refitting, the checkpoints will be overwritten."
130+ )
131+ if not self .save_weights_only :
132+ msg += (
133+ " To load the stored approximator from the checkpoint, "
134+ "use approximator = keras.saving.load_model(...)"
135+ )
136+
137+ logging .warning (msg )
117138 self .history = None
118139
119140 @staticmethod
0 commit comments