Skip to content

Commit ebd3e3a

Browse files
committed
change folder structure
1 parent f6b7194 commit ebd3e3a

File tree

1 file changed

+6
-3
lines changed

1 file changed

+6
-3
lines changed

n3fit/src/n3fit/model_trainer.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from collections import namedtuple
1313
from itertools import zip_longest
1414
import logging
15+
import pickle
1516

1617
import numpy as np
1718

@@ -742,7 +743,10 @@ def _train_and_fit(self, training_model, stopping_object, epochs=100) -> bool:
742743

743744
if self.save_checkpoints:
744745
pdf_model = training_model.get_layer("PDFs")
745-
replica_paths = [self.replica_path / f"replica_{r}" for r in self.replicas]
746+
# Save parameters where colibri will look for checkpoints
747+
replica_paths = [
748+
self.replica_path.parent / f"fit_replicas/replica_{r}" for r in self.replicas
749+
]
746750
checpoint_callback = callbacks.StoreCallback(
747751
pdf_model=pdf_model, replica_paths=replica_paths, check_freq=self.checkpoint_freq
748752
)
@@ -961,9 +965,8 @@ def hyperparametrizable(self, params):
961965
"layer_type": params["layer_type"],
962966
}
963967
state = {"_init_args": _init_args}
964-
import pickle
965968

966-
with open(self.replica_path / "pdf_model.pkl", "wb") as file:
969+
with open(self.replica_path.parent / "pdf_model.pkl", "wb") as file:
967970
pickle.dump(state, file)
968971

969972
### Training loop

0 commit comments

Comments
 (0)