Skip to content

Commit c92c918

Browse files
committed
Save best parameters
1 parent e7e9450 commit c92c918

File tree

2 files changed

+23
-10
lines changed

2 files changed

+23
-10
lines changed

n3fit/src/n3fit/backends/keras_backend/callbacks.py

Lines changed: 19 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -212,16 +212,25 @@ class StoreCallback(CallbackStep):
212212
Save every this many epochs (default: 100)
213213
"""
214214

215-
def __init__(self, pdf_model, replica_paths, check_freq=100):
215+
def __init__(self, pdf_model, replica_paths, stopping_object, check_freq=100):
216216
super().__init__()
217217
self.check_freq = check_freq
218218
self.pdf_model = pdf_model
219219
self.weight_dirs = []
220+
self.stopping_object = stopping_object
220221
for path in replica_paths:
221222
weight_dir = path / "parameters"
222223
weight_dir.mkdir(parents=True, exist_ok=True)
223224
self.weight_dirs.append(weight_dir)
224225

226+
def _save_weights(self, epoch, tr_weights, weight_dir):
227+
228+
filepath = weight_dir / f"params_{epoch+1}.npz"
229+
# save parameters as expected by colibri
230+
trainable_weights_flat = np.concatenate([np.asarray(w).flatten() for w in tr_weights])
231+
np.savez(filepath, params=trainable_weights_flat)
232+
log.info(f"Saved parameters at epoch {epoch+1} in {filepath}")
233+
225234
def on_step_end(self, epoch, logs=None):
226235
"""Function to be called at the end of every epoch
227236
Every ``check_freq`` number of epochs, the parameters of the model will
@@ -230,14 +239,15 @@ def on_step_end(self, epoch, logs=None):
230239
if ((epoch + 1) % self.check_freq) == 0:
231240
pdf_replicas = self.pdf_model.split_replicas()
232241
for replica_model, weight_dir in zip(pdf_replicas, self.weight_dirs):
233-
filepath = weight_dir / f"params_{epoch+1}.npz"
234-
# save parameters as expected by colibri
235-
trainable_weights_flat = np.concatenate(
236-
[w.numpy().flatten() for w in replica_model.trainable_weights]
237-
)
238-
np.savez(filepath, params=trainable_weights_flat)
239-
# replica_model.save_weights(filepath)
240-
log.info(f"Saved parameters at epoch {epoch+1} in {filepath}")
242+
weights = replica_model.trainable_weights
243+
self._save_weights(epoch, weights, weight_dir)
244+
245+
def on_train_end(self, logs=None):
246+
"""Store the best parameters"""
247+
for idx, weight_dir in enumerate(self.weight_dirs):
248+
best_epoch = self.stopping_object._best_epochs[idx]
249+
best_weights = self.stopping_object._best_weights[idx]['all_NNs']
250+
self._save_weights(best_epoch, best_weights, weight_dir)
241251

242252

243253
def gen_tensorboard_callback(log_dir, profiling=False, histogram_freq=0):

n3fit/src/n3fit/model_trainer.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -752,7 +752,10 @@ def _train_and_fit(self, training_model, stopping_object, epochs=100) -> bool:
752752
self.replica_path.parent / f"fit_replicas/replica_{r}" for r in self.replicas
753753
]
754754
checpoint_callback = callbacks.StoreCallback(
755-
pdf_model=pdf_model, replica_paths=replica_paths, check_freq=self.checkpoint_freq
755+
pdf_model=pdf_model,
756+
replica_paths=replica_paths,
757+
check_freq=self.checkpoint_freq,
758+
stopping_object=stopping_object,
756759
)
757760
callback_list.append(checpoint_callback)
758761

0 commit comments

Comments
 (0)