Commit 46d8fa2

add public reason api
1 parent cd30ce4 · commit 46d8fa2

1 file changed, +33 -1 lines

src/lightning/pytorch/callbacks/early_stopping.py

Lines changed: 33 additions & 1 deletion
@@ -20,6 +20,7 @@
 """
 
 import logging
+from enum import Enum
 from typing import Any, Callable, Optional
 
 import torch
@@ -34,6 +35,16 @@
 log = logging.getLogger(__name__)
 
 
+class EarlyStoppingReason(Enum):
+    """Enum for early stopping reasons."""
+
+    NOT_STOPPED = 0
+    STOPPING_THRESHOLD = 1
+    DIVERGENCE_THRESHOLD = 2
+    PATIENCE_EXHAUSTED = 3
+    NON_FINITE_METRIC = 4
+
+
 class EarlyStopping(Callback):
     r"""Monitor a metric and stop training when it stops improving.
 
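Since ``EarlyStoppingReason`` is a plain ``Enum``, the reason can be compared by identity and exposes ``.name``/``.value`` for logging. A minimal sketch of reading it off a callback instance (the ``early_stopping`` variable and the training run around it are assumptions, not part of this diff):

    from lightning.pytorch.callbacks import EarlyStopping, EarlyStoppingReason

    early_stopping = EarlyStopping("val_loss")
    # ... wire the callback into a Trainer and run fit() ...
    reason = early_stopping.stopping_reason
    print(reason.name)                                # e.g. "NOT_STOPPED" before any stop
    print(reason is EarlyStoppingReason.NOT_STOPPED)  # enum members compare by identity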
@@ -65,6 +76,11 @@ class EarlyStopping(Callback):
             If this is ``False``, then the check runs at the end of the validation.
         log_rank_zero_only: When set ``True``, logs the status of the early stopping callback only for rank 0 process.
 
+    Attributes:
+        stopped_epoch: The epoch at which training was stopped. 0 if training was not stopped.
+        stopping_reason: An ``EarlyStoppingReason`` enum indicating why training was stopped.
+        stopping_reason_message: A human-readable message explaining why training was stopped.
+
     Raises:
         MisconfigurationException:
             If ``mode`` is none of ``"min"`` or ``"max"``.
@@ -74,9 +90,12 @@ class EarlyStopping(Callback):
     Example::
 
         >>> from lightning.pytorch import Trainer
-        >>> from lightning.pytorch.callbacks import EarlyStopping
+        >>> from lightning.pytorch.callbacks import EarlyStopping, EarlyStoppingReason
         >>> early_stopping = EarlyStopping('val_loss')
         >>> trainer = Trainer(callbacks=[early_stopping])
+        >>> # After training...
+        >>> if early_stopping.stopping_reason == EarlyStoppingReason.PATIENCE_EXHAUSTED:
+        ...     print("Training stopped due to patience exhaustion")
 
     .. tip:: Saving and restoring multiple early stopping callbacks at the same time is supported under variation in the
         following arguments:
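Beyond the single ``PATIENCE_EXHAUSTED`` check in the doctest, one way to handle every member is a small dispatch table. A sketch, assuming the same ``early_stopping`` object and imports as in the doctest above (the summary strings are illustrative, not from the library):

    summaries = {
        EarlyStoppingReason.NOT_STOPPED: "training ran to completion",
        EarlyStoppingReason.STOPPING_THRESHOLD: "metric crossed stopping_threshold",
        EarlyStoppingReason.DIVERGENCE_THRESHOLD: "metric crossed divergence_threshold",
        EarlyStoppingReason.PATIENCE_EXHAUSTED: "no improvement for `patience` checks",
        EarlyStoppingReason.NON_FINITE_METRIC: "monitored metric became NaN or inf",
    }
    print(summaries[early_stopping.stopping_reason])
    # The callback also keeps the exact log line (or None if training never stopped):
    print(early_stopping.stopping_reason_message)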
@@ -117,6 +136,8 @@ def __init__(
         self.divergence_threshold = divergence_threshold
         self.wait_count = 0
         self.stopped_epoch = 0
+        self.stopping_reason = EarlyStoppingReason.NOT_STOPPED
+        self.stopping_reason_message = None
         self._check_on_train_epoch_end = check_on_train_epoch_end
         self.log_rank_zero_only = log_rank_zero_only
 
@@ -169,6 +190,8 @@ def state_dict(self) -> dict[str, Any]:
             "stopped_epoch": self.stopped_epoch,
             "best_score": self.best_score,
             "patience": self.patience,
+            "stopping_reason": self.stopping_reason,
+            "stopping_reason_message": self.stopping_reason_message,
         }
 
     @override
@@ -177,6 +200,9 @@ def load_state_dict(self, state_dict: dict[str, Any]) -> None:
         self.stopped_epoch = state_dict["stopped_epoch"]
         self.best_score = state_dict["best_score"]
         self.patience = state_dict["patience"]
+        # For backward compatibility, set defaults if not present
+        self.stopping_reason = state_dict.get("stopping_reason", EarlyStoppingReason.NOT_STOPPED)
+        self.stopping_reason_message = state_dict.get("stopping_reason_message")
 
     def _should_skip_check(self, trainer: "pl.Trainer") -> bool:
         from lightning.pytorch.trainer.states import TrainerFn
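The ``state_dict.get(...)`` calls are what let older checkpoints load cleanly: a state dict written before this commit simply lacks the two new keys, so the attributes fall back to their defaults. A quick sketch of that path (the hand-built ``old_state`` dict stands in for a pre-existing checkpoint):

    import torch
    from lightning.pytorch.callbacks import EarlyStopping, EarlyStoppingReason

    cb = EarlyStopping("val_loss")
    old_state = {  # keys as saved by versions before this commit
        "wait_count": 3,
        "stopped_epoch": 7,
        "best_score": torch.tensor(0.42),
        "patience": 3,
    }
    cb.load_state_dict(old_state)
    assert cb.stopping_reason is EarlyStoppingReason.NOT_STOPPED
    assert cb.stopping_reason_message is None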
@@ -212,6 +238,8 @@ def _run_early_stopping_check(self, trainer: "pl.Trainer") -> None:
         trainer.should_stop = trainer.should_stop or should_stop
         if should_stop:
             self.stopped_epoch = trainer.current_epoch
+            # Store the stopping reason message
+            self.stopping_reason_message = reason
         if reason and self.verbose:
             self._log_info(trainer, reason, self.log_rank_zero_only)
 
@@ -220,19 +248,22 @@ def _evaluate_stopping_criteria(self, current: Tensor) -> tuple[bool, Optional[str]]:
         reason = None
         if self.check_finite and not torch.isfinite(current):
             should_stop = True
+            self.stopping_reason = EarlyStoppingReason.NON_FINITE_METRIC
             reason = (
                 f"Monitored metric {self.monitor} = {current} is not finite."
                 f" Previous best value was {self.best_score:.3f}. Signaling Trainer to stop."
             )
         elif self.stopping_threshold is not None and self.monitor_op(current, self.stopping_threshold):
             should_stop = True
+            self.stopping_reason = EarlyStoppingReason.STOPPING_THRESHOLD
             reason = (
                 "Stopping threshold reached:"
                 f" {self.monitor} = {current} {self.order_dict[self.mode]} {self.stopping_threshold}."
                 " Signaling Trainer to stop."
             )
         elif self.divergence_threshold is not None and self.monitor_op(-current, -self.divergence_threshold):
             should_stop = True
+            self.stopping_reason = EarlyStoppingReason.DIVERGENCE_THRESHOLD
             reason = (
                 "Divergence threshold reached:"
                 f" {self.monitor} = {current} {self.order_dict[self.mode]} {self.divergence_threshold}."
@@ -247,6 +278,7 @@ def _evaluate_stopping_criteria(self, current: Tensor) -> tuple[bool, Optional[str]]:
             self.wait_count += 1
             if self.wait_count >= self.patience:
                 should_stop = True
+                self.stopping_reason = EarlyStoppingReason.PATIENCE_EXHAUSTED
                 reason = (
                     f"Monitored metric {self.monitor} did not improve in the last {self.wait_count} records."
                     f" Best score: {self.best_score:.3f}. Signaling Trainer to stop."
