Skip to content

Commit a32daab

Browse files
committed
Remove stopping and add would stop
1 parent 1f68580 commit a32daab

File tree

4 files changed

+29
-1
lines changed

4 files changed

+29
-1
lines changed

n3fit/src/n3fit/io/writer.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -279,11 +279,18 @@ def write_data(self, save_path, fitname, weights_name):
279279
replica_path.mkdir(exist_ok=True, parents=True)
280280

281281
self._write_chi2s(replica_path / "chi2exps.log")
282+
self._write_would_stop_epoch(replica_path / "would_stop_epoch.txt")
282283
self._write_metadata_json(i, replica_path / f"{fitname}.json")
283284
self._export_pdf_grid(i, replica_path / f"{fitname}.exportgrid")
284285
if weights_name:
285286
self._write_weights(i, replica_path / f"{weights_name}")
286287

288+
def _write_would_stop_epoch(self, out_path):
    """Write the would-stop epoch of the stopping object to ``out_path``.

    The file holds a single line: the string form of
    ``self.stopping_object.would_stop_epoch``, or the literal text
    ``None`` when that value is ``None``.

    Parameters
    ----------
    out_path: path-like
        destination file; it is opened in write mode, so any existing
        content is overwritten.
    """
    recorded = self.stopping_object.would_stop_epoch
    text = "None" if recorded is None else str(recorded)
    # A single write of value + newline produces the same file content
    # as writing the two pieces separately.
    with open(out_path, "w", encoding="utf-8") as handle:
        handle.write(f"{text}\n")
293+
287294
def _write_chi2s(self, out_path):
288295
# Note: same for all replicas, unless run separately
289296
chi2_log = self.stopping_object.chi2exps_json()

n3fit/src/n3fit/model_trainer.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -116,6 +116,7 @@ def __init__(
116116
save_checkpoints=False,
117117
replica_path=None,
118118
checkpoint_freq=100,
119+
dont_stop=False,
119120
):
120121
"""
121122
Parameters
@@ -163,6 +164,8 @@ def __init__(
163164
root path for all replicas.
164165
checkpoint_freq: int
165166
frequency (in epochs) at which to save checkpoints. Only relevant if `save_checkpoints` is True.
167+
dont_stop: bool
168+
whether to disable the stopping mechanism, i.e. to run for all epochs regardless of the validation chi2
166169
"""
167170
# Save all input information
168171
self.exp_info = list(exp_info)
@@ -179,6 +182,7 @@ def __init__(
179182
self.lux_params = lux_params
180183
self.replicas = replicas
181184
self.experiments_data = experiments_data
185+
self.dont_stop = dont_stop
182186

183187
# Checkpointing options
184188
self.save_checkpoints = save_checkpoints
@@ -1035,6 +1039,7 @@ def hyperparametrizable(self, params):
10351039
stopping_patience=stopping_epochs,
10361040
threshold_positivity=threshold_pos,
10371041
threshold_chi2=threshold_chi2,
1042+
dont_stop=self.dont_stop,
10381043
)
10391044

10401045
# Compile each of the models with the right parameters

n3fit/src/n3fit/performfit.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ def performfit(
4444
parallel_models=True,
4545
save_checkpoints=False,
4646
checkpoint_freq=100,
47+
dont_stop=False,
4748
):
4849
"""
4950
This action will (upon having read a validcard) process a full PDF fit
@@ -204,6 +205,7 @@ def performfit(
204205
save_checkpoints=save_checkpoints,
205206
replica_path=replica_path,
206207
checkpoint_freq=checkpoint_freq,
208+
dont_stop=dont_stop,
207209
)
208210

209211
# This is just to give a descriptive name to the fit function

n3fit/src/n3fit/stopping.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -345,6 +345,7 @@ def __init__(
345345

346346
self._dont_stop = dont_stop
347347
self._stop_now = False
348+
self._would_stop_epoch = None
348349
self.stopping_patience = stopping_patience
349350
self.total_epochs = total_epochs
350351

@@ -481,7 +482,20 @@ def make_stop(self):
481482
and reload the history to the point of the best model if any
482483
"""
483484
self._stop_now = True
484-
self._restore_best_weights()
485+
if self._would_stop_epoch is None:
486+
# final_epoch is the last registered epoch (0-indexed); +1 to match stop_epoch convention
487+
self._would_stop_epoch = (
488+
-1 if self._history.final_epoch is None else self._history.final_epoch + 1
489+
)
490+
if not self._dont_stop:
491+
self._restore_best_weights()
492+
493+
@property
def would_stop_epoch(self):
    """Epoch recorded the first time ``make_stop`` was called.

    Returns
    -------
    int or None
        - ``None`` while ``make_stop`` has not been called yet (the value
          the attribute is initialised with).
        - ``-1`` if ``make_stop`` was called while the history had no
          registered final epoch (``final_epoch`` was ``None``).
        - otherwise the history's ``final_epoch + 1`` at the time of that
          first call.

    The value is recorded whether or not ``dont_stop`` is set; only the
    restoring of the best weights is skipped when ``dont_stop`` is True.

    NOTE(review): presumably this matches ``stop_epoch`` when
    ``dont_stop=False`` -- confirm against the ``stop_epoch`` definition.
    """
    return self._would_stop_epoch
485499

486500
def _restore_best_weights(self):
487501
for i_replica, weights in enumerate(self._best_weights):

0 commit comments

Comments
 (0)