@@ -212,16 +212,25 @@ class StoreCallback(CallbackStep):
212212 Save every this many epochs (default: 100)
213213 """
214214
215- def __init__ (self , pdf_model , replica_paths , check_freq = 100 ):
215+ def __init__ (self , pdf_model , replica_paths , stopping_object , check_freq = 100 ):
216216 super ().__init__ ()
217217 self .check_freq = check_freq
218218 self .pdf_model = pdf_model
219219 self .weight_dirs = []
220+ self .stopping_object = stopping_object
220221 for path in replica_paths :
221222 weight_dir = path / "parameters"
222223 weight_dir .mkdir (parents = True , exist_ok = True )
223224 self .weight_dirs .append (weight_dir )
224225
226+ def _save_weights (self , epoch , tr_weights , weight_dir ):
227+
228+ filepath = weight_dir / f"params_{ epoch + 1 } .npz"
229+ # save parameters as expected by colibri
230+ trainable_weights_flat = np .concatenate ([np .asarray (w ).flatten () for w in tr_weights ])
231+ np .savez (filepath , params = trainable_weights_flat )
232+ log .info (f"Saved parameters at epoch { epoch + 1 } in { filepath } " )
233+
225234 def on_step_end (self , epoch , logs = None ):
226235 """Function to be called at the end of every epoch
227236 Every ``check_freq`` number of epochs, the parameters of the model will
@@ -230,14 +239,15 @@ def on_step_end(self, epoch, logs=None):
230239 if ((epoch + 1 ) % self .check_freq ) == 0 :
231240 pdf_replicas = self .pdf_model .split_replicas ()
232241 for replica_model , weight_dir in zip (pdf_replicas , self .weight_dirs ):
233- filepath = weight_dir / f"params_{ epoch + 1 } .npz"
234- # save parameters as expected by colibri
235- trainable_weights_flat = np .concatenate (
236- [w .numpy ().flatten () for w in replica_model .trainable_weights ]
237- )
238- np .savez (filepath , params = trainable_weights_flat )
239- # replica_model.save_weights(filepath)
240- log .info (f"Saved parameters at epoch { epoch + 1 } in { filepath } " )
242+ weights = replica_model .trainable_weights
243+ self ._save_weights (epoch , weights , weight_dir )
244+
245+ def on_train_end (self , logs = None ):
246+ """Store the best parameters"""
247+ for idx , weight_dir in enumerate (self .weight_dirs ):
248+ best_epoch = self .stopping_object ._best_epochs [idx ]
249+ best_weights = self .stopping_object ._best_weights [idx ]['all_NNs' ]
250+ self ._save_weights (best_epoch , best_weights , weight_dir )
241251
242252
243253def gen_tensorboard_callback (log_dir , profiling = False , histogram_freq = 0 ):
0 commit comments