Merged
34 changes: 32 additions & 2 deletions docs/source-pytorch/common/early_stopping.rst
@@ -1,6 +1,7 @@
.. testsetup:: *

from lightning.pytorch.callbacks.early_stopping import EarlyStopping
from lightning.pytorch.callbacks.early_stopping import EarlyStopping, EarlyStoppingReason
from lightning.pytorch import Trainer, LightningModule

.. _early_stopping:

@@ -71,8 +72,37 @@ Additional parameters that stop training at extreme points:
- ``check_on_train_epoch_end``: When turned on, the metric is checked at the end of each training epoch. Use this only when you are monitoring a metric logged within
  training-specific hooks at the epoch level (see the sketch below).
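
A minimal configuration sketch combining these parameters; the monitored metric name and threshold values are illustrative, not defaults:

.. testcode::

    from lightning.pytorch.callbacks import EarlyStopping

    early_stopping = EarlyStopping(
        monitor="val_loss",
        mode="min",
        stopping_threshold=0.05,  # stop once val_loss is at least this good
        divergence_threshold=10.0,  # stop as soon as val_loss gets this bad
        check_finite=True,  # stop if val_loss becomes NaN or infinite
        check_on_train_epoch_end=True,  # run the check after each training epoch
    )
    trainer = Trainer(callbacks=[early_stopping])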

**Accessing Stopping Reason**

After training completes, you can programmatically check why early stopping occurred through the ``stopping_reason`` attribute, which holds an ``EarlyStoppingReason`` enum value.

.. testcode::

from lightning.pytorch.callbacks import EarlyStopping, EarlyStoppingReason

early_stopping = EarlyStopping(monitor="val_loss", patience=3)
trainer = Trainer(callbacks=[early_stopping])
trainer.fit(model)

# Check why training stopped
if early_stopping.stopping_reason == EarlyStoppingReason.PATIENCE_EXHAUSTED:
print("Training stopped due to patience exhaustion")
elif early_stopping.stopping_reason == EarlyStoppingReason.STOPPING_THRESHOLD:
print("Training stopped due to reaching stopping threshold")
elif early_stopping.stopping_reason == EarlyStoppingReason.NOT_STOPPED:
print("Training completed normally without early stopping")

# Access human-readable message
if early_stopping.stopping_reason_message:
print(f"Details: {early_stopping.stopping_reason_message}")

The available stopping reasons are:

- ``NOT_STOPPED``: Training completed normally without early stopping
- ``STOPPING_THRESHOLD``: Training stopped because the monitored metric reached the stopping threshold
- ``DIVERGENCE_THRESHOLD``: Training stopped because the monitored metric exceeded the divergence threshold
- ``PATIENCE_EXHAUSTED``: Training stopped because the metric didn't improve for the specified patience
- ``NON_FINITE_METRIC``: Training stopped because the monitored metric became NaN or infinite

In case you need early stopping in a different part of training, subclass :class:`~lightning.pytorch.callbacks.early_stopping.EarlyStopping`
and change where it is called:

.. testcode::
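
    # A minimal sketch, assuming the check should run once at the end of training
    # instead of after each validation epoch; adapt the hooks to your use case.
    from lightning.pytorch.callbacks import EarlyStopping

    class MyEarlyStopping(EarlyStopping):
        def on_validation_end(self, trainer, pl_module):
            # skip the early-stopping check at the end of the validation loop
            pass

        def on_train_end(self, trainer, pl_module):
            # run the check at the end of the training loop instead
            self._run_early_stopping_check(trainer)
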
3 changes: 3 additions & 0 deletions src/lightning/pytorch/CHANGELOG.md
@@ -19,6 +19,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added time-based validation support through `val_check_interval` ([#21071](https://github.com/Lightning-AI/pytorch-lightning/pull/21071))


- Added attributes to access stopping reason in `EarlyStopping` callback ([#21188](https://github.com/Lightning-AI/pytorch-lightning/pull/21188))


### Changed

- Default to `RichProgressBar` and `RichModelSummary` if the rich package is available. Fallback to TQDMProgressBar and ModelSummary otherwise. ([#9580](https://github.com/Lightning-AI/pytorch-lightning/pull/9580))
34 changes: 33 additions & 1 deletion src/lightning/pytorch/callbacks/early_stopping.py
@@ -20,6 +20,7 @@
"""

import logging
from enum import Enum
from typing import Any, Callable, Optional

import torch
@@ -34,6 +35,16 @@
log = logging.getLogger(__name__)


class EarlyStoppingReason(Enum):
"""Enum for early stopping reasons."""

NOT_STOPPED = 0
STOPPING_THRESHOLD = 1
DIVERGENCE_THRESHOLD = 2
PATIENCE_EXHAUSTED = 3
NON_FINITE_METRIC = 4


class EarlyStopping(Callback):
r"""Monitor a metric and stop training when it stops improving.

@@ -65,6 +76,11 @@ class EarlyStopping(Callback):
If this is ``False``, then the check runs at the end of the validation.
log_rank_zero_only: When set ``True``, logs the status of the early stopping callback only for the rank 0 process.

Attributes:
stopped_epoch: The epoch at which training was stopped. 0 if training was not stopped.
stopping_reason: An ``EarlyStoppingReason`` enum indicating why training was stopped.
stopping_reason_message: A human-readable message explaining why training was stopped.

Raises:
MisconfigurationException:
If ``mode`` is none of ``"min"`` or ``"max"``.
@@ -74,9 +90,12 @@
Example::

>>> from lightning.pytorch import Trainer
>>> from lightning.pytorch.callbacks import EarlyStopping
>>> from lightning.pytorch.callbacks import EarlyStopping, EarlyStoppingReason
>>> early_stopping = EarlyStopping('val_loss')
>>> trainer = Trainer(callbacks=[early_stopping])
>>> # After training...
>>> if early_stopping.stopping_reason == EarlyStoppingReason.PATIENCE_EXHAUSTED:
... print("Training stopped due to patience exhaustion")

.. tip:: Saving and restoring multiple early stopping callbacks at the same time is supported under variation in the
following arguments:
@@ -117,6 +136,8 @@ def __init__(
self.divergence_threshold = divergence_threshold
self.wait_count = 0
self.stopped_epoch = 0
self.stopping_reason = EarlyStoppingReason.NOT_STOPPED
self.stopping_reason_message = None
self._check_on_train_epoch_end = check_on_train_epoch_end
self.log_rank_zero_only = log_rank_zero_only

@@ -169,6 +190,8 @@ def state_dict(self) -> dict[str, Any]:
"stopped_epoch": self.stopped_epoch,
"best_score": self.best_score,
"patience": self.patience,
"stopping_reason": self.stopping_reason,
"stopping_reason_message": self.stopping_reason_message,
}

@override
@@ -177,6 +200,9 @@ def load_state_dict(self, state_dict: dict[str, Any]) -> None:
self.stopped_epoch = state_dict["stopped_epoch"]
self.best_score = state_dict["best_score"]
self.patience = state_dict["patience"]
# For backward compatibility, set defaults if not present
self.stopping_reason = state_dict.get("stopping_reason", EarlyStoppingReason.NOT_STOPPED)
self.stopping_reason_message = state_dict.get("stopping_reason_message")

def _should_skip_check(self, trainer: "pl.Trainer") -> bool:
from lightning.pytorch.trainer.states import TrainerFn
@@ -212,6 +238,8 @@ def _run_early_stopping_check(self, trainer: "pl.Trainer") -> None:
trainer.should_stop = trainer.should_stop or should_stop
if should_stop:
self.stopped_epoch = trainer.current_epoch
# Store the stopping reason message
self.stopping_reason_message = reason
if reason and self.verbose:
self._log_info(trainer, reason, self.log_rank_zero_only)

@@ -220,19 +248,22 @@ def _evaluate_stopping_criteria(self, current: Tensor) -> tuple[bool, Optional[s
reason = None
if self.check_finite and not torch.isfinite(current):
should_stop = True
self.stopping_reason = EarlyStoppingReason.NON_FINITE_METRIC
reason = (
f"Monitored metric {self.monitor} = {current} is not finite."
f" Previous best value was {self.best_score:.3f}. Signaling Trainer to stop."
)
elif self.stopping_threshold is not None and self.monitor_op(current, self.stopping_threshold):
should_stop = True
self.stopping_reason = EarlyStoppingReason.STOPPING_THRESHOLD
reason = (
"Stopping threshold reached:"
f" {self.monitor} = {current} {self.order_dict[self.mode]} {self.stopping_threshold}."
" Signaling Trainer to stop."
)
elif self.divergence_threshold is not None and self.monitor_op(-current, -self.divergence_threshold):
should_stop = True
self.stopping_reason = EarlyStoppingReason.DIVERGENCE_THRESHOLD
reason = (
"Divergence threshold reached:"
f" {self.monitor} = {current} {self.order_dict[self.mode]} {self.divergence_threshold}."
@@ -247,6 +278,7 @@ def _evaluate_stopping_criteria(self, current: Tensor) -> tuple[bool, Optional[s
self.wait_count += 1
if self.wait_count >= self.patience:
should_stop = True
self.stopping_reason = EarlyStoppingReason.PATIENCE_EXHAUSTED
reason = (
f"Monitored metric {self.monitor} did not improve in the last {self.wait_count} records."
f" Best score: {self.best_score:.3f}. Signaling Trainer to stop."
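
As a hedged illustration of the backward-compatibility path in ``load_state_dict`` above: restoring a state dict saved before these attributes existed (and therefore lacking the new keys) leaves them at their defaults. The dictionary values below are illustrative::

    import torch

    from lightning.pytorch.callbacks.early_stopping import EarlyStopping, EarlyStoppingReason

    early_stopping = EarlyStopping(monitor="val_loss", patience=3)

    # a state dict produced by an older version has no stopping_reason entries
    old_state = {"wait_count": 2, "stopped_epoch": 5, "best_score": torch.tensor(0.5), "patience": 3}
    early_stopping.load_state_dict(old_state)

    assert early_stopping.stopping_reason is EarlyStoppingReason.NOT_STOPPED
    assert early_stopping.stopping_reason_message is None
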
189 changes: 189 additions & 0 deletions tests/tests_pytorch/callbacks/test_early_stopping.py
@@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import logging
import math
import os
@@ -25,6 +26,7 @@

from lightning.pytorch import Trainer, seed_everything
from lightning.pytorch.callbacks import EarlyStopping, ModelCheckpoint
from lightning.pytorch.callbacks.early_stopping import EarlyStoppingReason
from lightning.pytorch.demos.boring_classes import BoringModel
from lightning.pytorch.utilities.exceptions import MisconfigurationException
from tests_pytorch.helpers.datamodules import ClassifDataModule
@@ -505,3 +507,190 @@ def test_early_stopping_log_info(log_rank_zero_only, world_size, global_rank, ex
log_mock.assert_called_once_with(expected_log)
else:
log_mock.assert_not_called()


class ModelWithHighLoss(BoringModel):
def on_validation_epoch_end(self):
self.log("val_loss", 10.0)


class ModelWithDecreasingLoss(BoringModel):
def __init__(self):
super().__init__()
self.epoch_losses = [5.0, 3.0, 1.0, 0.5]

def on_validation_epoch_end(self):
loss = self.epoch_losses[self.current_epoch] if self.current_epoch < len(self.epoch_losses) else 0.1
self.log("val_loss", loss)


class ModelWithIncreasingLoss(BoringModel):
def __init__(self):
super().__init__()
self.epoch_losses = [1.0, 2.0, 5.0, 10.0]

def on_validation_epoch_end(self):
loss = self.epoch_losses[self.current_epoch] if self.current_epoch < len(self.epoch_losses) else 15.0
self.log("val_loss", loss)


class ModelWithNaNLoss(BoringModel):
def __init__(self):
super().__init__()
self.epoch_losses = [1.0, 0.5, float("nan")]

def on_validation_epoch_end(self):
loss = self.epoch_losses[self.current_epoch] if self.current_epoch < len(self.epoch_losses) else float("nan")
self.log("val_loss", loss)


class ModelWithImprovingLoss(BoringModel):
def __init__(self):
super().__init__()
self.epoch_losses = [5.0, 4.0, 3.0, 2.0, 1.0]

def on_validation_epoch_end(self):
loss = self.epoch_losses[self.current_epoch] if self.current_epoch < len(self.epoch_losses) else 0.1
self.log("val_loss", loss)


@pytest.mark.parametrize(
(
"model_cls",
"early_stopping_kwargs",
"trainer_kwargs",
"expected_reason",
"reason_message_substr",
"should_stop",
"state_dict_override",
),
[
# Patience exhausted
(
ModelWithHighLoss,
{"monitor": "val_loss", "patience": 2, "verbose": True},
{"max_epochs": 10, "enable_progress_bar": False},
EarlyStoppingReason.PATIENCE_EXHAUSTED,
"did not improve",
True,
None,
),
# Stopping threshold
(
ModelWithDecreasingLoss,
{"monitor": "val_loss", "stopping_threshold": 0.6, "mode": "min", "verbose": True},
{"max_epochs": 10, "enable_progress_bar": False},
EarlyStoppingReason.STOPPING_THRESHOLD,
"Stopping threshold reached",
True,
None,
),
# Divergence threshold
(
ModelWithIncreasingLoss,
{"monitor": "val_loss", "divergence_threshold": 8.0, "mode": "min", "verbose": True},
{"max_epochs": 10, "enable_progress_bar": False},
EarlyStoppingReason.DIVERGENCE_THRESHOLD,
"Divergence threshold reached",
True,
None,
),
# Non-finite metric
(
ModelWithNaNLoss,
{"monitor": "val_loss", "check_finite": True, "verbose": True},
{"max_epochs": 10, "enable_progress_bar": False},
EarlyStoppingReason.NON_FINITE_METRIC,
"is not finite",
True,
None,
),
# Not stopped (normal completion)
(
ModelWithImprovingLoss,
{"monitor": "val_loss", "patience": 3, "verbose": True},
{"max_epochs": 3, "enable_progress_bar": False},
EarlyStoppingReason.NOT_STOPPED,
None,
False,
None,
),
# State persistence
(
None,
{"monitor": "val_loss", "patience": 3},
{},
EarlyStoppingReason.PATIENCE_EXHAUSTED,
"Test message",
None,
{"stopping_reason": EarlyStoppingReason.PATIENCE_EXHAUSTED, "stopping_reason_message": "Test message"},
),
# Backward compatibility (old state dict)
(
None,
{"monitor": "val_loss", "patience": 3},
{},
EarlyStoppingReason.NOT_STOPPED,
None,
None,
{
"wait_count": 2,
"stopped_epoch": 5,
"best_score": torch.tensor(0.5),
"patience": 3,
},
),
],
)
def test_early_stopping_reasons(
tmp_path,
model_cls,
early_stopping_kwargs,
trainer_kwargs,
expected_reason,
reason_message_substr,
should_stop,
state_dict_override,
):
"""Test all early stopping reasons in a single parametrized test."""
if state_dict_override is not None:
early_stopping = EarlyStopping(**early_stopping_kwargs)
if "stopping_reason" in state_dict_override:
# State persistence test
early_stopping.stopping_reason = state_dict_override["stopping_reason"]
early_stopping.stopping_reason_message = state_dict_override["stopping_reason_message"]
state_dict = early_stopping.state_dict()
new_early_stopping = EarlyStopping(**early_stopping_kwargs)
new_early_stopping.load_state_dict(state_dict)
assert new_early_stopping.stopping_reason == expected_reason
assert new_early_stopping.stopping_reason_message == reason_message_substr
else:
# Backward compatibility test
early_stopping.load_state_dict(copy.deepcopy(state_dict_override))
assert early_stopping.stopping_reason == expected_reason
assert early_stopping.stopping_reason_message is None
assert early_stopping.wait_count == state_dict_override["wait_count"]
assert early_stopping.stopped_epoch == state_dict_override["stopped_epoch"]
return

# All other tests
model = model_cls()
early_stopping = EarlyStopping(**early_stopping_kwargs)
trainer = Trainer(
default_root_dir=tmp_path,
callbacks=[early_stopping],
**trainer_kwargs,
)
trainer.fit(model)

assert early_stopping.stopping_reason == expected_reason
if reason_message_substr is not None:
assert early_stopping.stopping_reason_message is not None
assert reason_message_substr in early_stopping.stopping_reason_message
else:
assert early_stopping.stopping_reason_message is None
if should_stop is not None:
if should_stop:
assert early_stopping.stopped_epoch > 0
else:
assert early_stopping.stopped_epoch == 0