diff --git a/docs/source-pytorch/common/early_stopping.rst b/docs/source-pytorch/common/early_stopping.rst
index 6cf111941ed97..235d3a6794527 100644
--- a/docs/source-pytorch/common/early_stopping.rst
+++ b/docs/source-pytorch/common/early_stopping.rst
@@ -1,6 +1,7 @@
 .. testsetup:: *
 
-    from lightning.pytorch.callbacks.early_stopping import EarlyStopping
+    from lightning.pytorch.callbacks.early_stopping import EarlyStopping, EarlyStoppingReason
+    from lightning.pytorch import Trainer, LightningModule
 
 .. _early_stopping:
 
@@ -71,6 +72,37 @@ Additional parameters that stop training at extreme points:
 - ``check_on_train_epoch_end``: When turned on, it checks the metric at the end of a training epoch. Use this only when
   you are monitoring any metric logged within training-specific hooks on epoch-level.
 
+After training completes, you can programmatically check why early stopping occurred using the ``stopping_reason``
+attribute, which holds an ``EarlyStoppingReason`` enum value.
+
+.. code-block:: python
+
+    from lightning.pytorch.callbacks import EarlyStopping
+    from lightning.pytorch.callbacks.early_stopping import EarlyStoppingReason
+
+    early_stopping = EarlyStopping(monitor="val_loss", patience=3)
+    trainer = Trainer(callbacks=[early_stopping])
+    trainer.fit(model)
+
+    # Check why training stopped
+    if early_stopping.stopping_reason == EarlyStoppingReason.PATIENCE_EXHAUSTED:
+        print("Training stopped due to patience exhaustion")
+    elif early_stopping.stopping_reason == EarlyStoppingReason.STOPPING_THRESHOLD:
+        print("Training stopped due to reaching the stopping threshold")
+    elif early_stopping.stopping_reason == EarlyStoppingReason.NOT_STOPPED:
+        print("Training completed normally without early stopping")
+
+    # Access the human-readable message
+    if early_stopping.stopping_reason_message:
+        print(f"Details: {early_stopping.stopping_reason_message}")
+
+The available stopping reasons are:
+
+- ``NOT_STOPPED``: Training completed normally without early stopping
+- ``STOPPING_THRESHOLD``: Training stopped because the monitored metric reached the stopping threshold
+- ``DIVERGENCE_THRESHOLD``: Training stopped because the monitored metric exceeded the divergence threshold
+- ``PATIENCE_EXHAUSTED``: Training stopped because the monitored metric did not improve within the configured patience
+- ``NON_FINITE_METRIC``: Training stopped because the monitored metric became NaN or infinite
 
 In case you need early stopping in a different part of training, subclass :class:`~lightning.pytorch.callbacks.early_stopping.EarlyStopping`
 and change where it is called:
diff --git a/src/lightning/pytorch/CHANGELOG.md b/src/lightning/pytorch/CHANGELOG.md
index ddc12a92e9f56..d001fa4b25412 100644
--- a/src/lightning/pytorch/CHANGELOG.md
+++ b/src/lightning/pytorch/CHANGELOG.md
@@ -19,6 +19,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Added time-based validation support through `val_check_interval` ([#21071](https://github.com/Lightning-AI/pytorch-lightning/pull/21071))
 
+- Added attributes to access the stopping reason in the `EarlyStopping` callback ([#21188](https://github.com/Lightning-AI/pytorch-lightning/pull/21188))
+
+
 - Added support for variable batch size in `ThroughputMonitor` ([#20236](https://github.com/Lightning-AI/pytorch-lightning/pull/20236))
 
 
diff --git a/src/lightning/pytorch/callbacks/early_stopping.py b/src/lightning/pytorch/callbacks/early_stopping.py
index d108894f614e6..7569705b9d4ea 100644
--- a/src/lightning/pytorch/callbacks/early_stopping.py
+++ b/src/lightning/pytorch/callbacks/early_stopping.py
@@ -20,6 +20,7 @@
 """
 
 import logging
+from enum import Enum
 from typing import Any, Callable, Optional
 
 import torch
@@ -34,6 +35,16 @@
 log = logging.getLogger(__name__)
 
 
+class EarlyStoppingReason(Enum):
+    """Enum for early stopping reasons."""
+
+    NOT_STOPPED = 0
+    STOPPING_THRESHOLD = 1
+    DIVERGENCE_THRESHOLD = 2
+    PATIENCE_EXHAUSTED = 3
+    NON_FINITE_METRIC = 4
+
+
 class EarlyStopping(Callback):
     r"""Monitor a metric and stop training when it stops improving.
 
@@ -65,6 +76,11 @@ class EarlyStopping(Callback):
             If this is ``False``, then the check runs at the end of the validation.
         log_rank_zero_only: When set ``True``, logs the status of the early stopping callback only for rank 0 process.
 
+    Attributes:
+        stopped_epoch: The epoch at which training was stopped. 0 if training was not stopped.
+        stopping_reason: An ``EarlyStoppingReason`` enum indicating why training was stopped.
+        stopping_reason_message: A human-readable message explaining why training was stopped.
+
     Raises:
         MisconfigurationException:
             If ``mode`` is none of ``"min"`` or ``"max"``.
@@ -75,8 +91,12 @@ class EarlyStopping(Callback):
 
         >>> from lightning.pytorch import Trainer
         >>> from lightning.pytorch.callbacks import EarlyStopping
+        >>> from lightning.pytorch.callbacks.early_stopping import EarlyStoppingReason
        >>> early_stopping = EarlyStopping('val_loss')
         >>> trainer = Trainer(callbacks=[early_stopping])
+        >>> # After training...
+        >>> if early_stopping.stopping_reason == EarlyStoppingReason.PATIENCE_EXHAUSTED:
+        ...     print("Training stopped due to patience exhaustion")
 
     .. tip:: Saving and restoring multiple early stopping callbacks at the same time is supported under variation in the
         following arguments:
@@ -117,6 +137,8 @@ def __init__(
         self.divergence_threshold = divergence_threshold
         self.wait_count = 0
         self.stopped_epoch = 0
+        self.stopping_reason = EarlyStoppingReason.NOT_STOPPED
+        self.stopping_reason_message: Optional[str] = None
         self._check_on_train_epoch_end = check_on_train_epoch_end
         self.log_rank_zero_only = log_rank_zero_only
 
@@ -169,6 +191,8 @@ def state_dict(self) -> dict[str, Any]:
             "stopped_epoch": self.stopped_epoch,
             "best_score": self.best_score,
             "patience": self.patience,
+            "stopping_reason": self.stopping_reason.value,
+            "stopping_reason_message": self.stopping_reason_message,
         }
 
     @override
@@ -177,6 +201,9 @@ def load_state_dict(self, state_dict: dict[str, Any]) -> None:
         self.stopped_epoch = state_dict["stopped_epoch"]
         self.best_score = state_dict["best_score"]
         self.patience = state_dict["patience"]
+        stopping_reason_value = state_dict.get("stopping_reason", EarlyStoppingReason.NOT_STOPPED.value)
+        self.stopping_reason = EarlyStoppingReason(stopping_reason_value)
+        self.stopping_reason_message = state_dict.get("stopping_reason_message")
 
     def _should_skip_check(self, trainer: "pl.Trainer") -> bool:
         from lightning.pytorch.trainer.states import TrainerFn
@@ -212,6 +239,7 @@ def _run_early_stopping_check(self, trainer: "pl.Trainer") -> None:
         trainer.should_stop = trainer.should_stop or should_stop
         if should_stop:
             self.stopped_epoch = trainer.current_epoch
+            self.stopping_reason_message = reason
         if reason and self.verbose:
             self._log_info(trainer, reason, self.log_rank_zero_only)
 
@@ -220,12 +248,14 @@ def _evaluate_stopping_criteria(self, current: Tensor) -> tuple[bool, Optional[s
         reason = None
         if self.check_finite and not torch.isfinite(current):
             should_stop = True
+            self.stopping_reason = EarlyStoppingReason.NON_FINITE_METRIC
             reason = (
                 f"Monitored metric {self.monitor} = {current} is not finite."
                 f" Previous best value was {self.best_score:.3f}. Signaling Trainer to stop."
             )
         elif self.stopping_threshold is not None and self.monitor_op(current, self.stopping_threshold):
             should_stop = True
+            self.stopping_reason = EarlyStoppingReason.STOPPING_THRESHOLD
             reason = (
                 "Stopping threshold reached:"
                 f" {self.monitor} = {current} {self.order_dict[self.mode]} {self.stopping_threshold}."
@@ -233,6 +263,7 @@ def _evaluate_stopping_criteria(self, current: Tensor) -> tuple[bool, Optional[s
             )
         elif self.divergence_threshold is not None and self.monitor_op(-current, -self.divergence_threshold):
             should_stop = True
+            self.stopping_reason = EarlyStoppingReason.DIVERGENCE_THRESHOLD
             reason = (
                 "Divergence threshold reached:"
                 f" {self.monitor} = {current} {self.order_dict[self.mode]} {self.divergence_threshold}."
@@ -247,6 +278,7 @@ def _evaluate_stopping_criteria(self, current: Tensor) -> tuple[bool, Optional[s
             self.wait_count += 1
             if self.wait_count >= self.patience:
                 should_stop = True
+                self.stopping_reason = EarlyStoppingReason.PATIENCE_EXHAUSTED
                 reason = (
                     f"Monitored metric {self.monitor} did not improve in the last {self.wait_count} records."
                     f" Best score: {self.best_score:.3f}. Signaling Trainer to stop."
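
With the callback changes above, the outcome of a run can be inspected after ``trainer.fit()`` returns, so calling code can branch on the result instead of parsing log output. The sketch below shows one way a training script might consume the new attributes; the ``build_model`` factory and the retry-on-instability policy are illustrative assumptions, not part of this PR:

    from lightning.pytorch import Trainer
    from lightning.pytorch.callbacks import EarlyStopping
    from lightning.pytorch.callbacks.early_stopping import EarlyStoppingReason

    # Reasons that indicate an unstable run rather than a converged one.
    UNSTABLE = {EarlyStoppingReason.NON_FINITE_METRIC, EarlyStoppingReason.DIVERGENCE_THRESHOLD}


    def fit_with_retry(build_model, max_retries=2, lr=1e-3):
        """Refit with a smaller learning rate whenever the monitored metric diverges or goes non-finite."""
        for _ in range(max_retries + 1):
            model = build_model(lr)  # hypothetical factory returning a fresh LightningModule
            early_stopping = EarlyStopping(
                monitor="val_loss", patience=3, divergence_threshold=10.0, check_finite=True
            )
            trainer = Trainer(max_epochs=20, callbacks=[early_stopping])
            trainer.fit(model)
            if early_stopping.stopping_reason not in UNSTABLE:
                return trainer, early_stopping
            # stopping_reason_message carries the same text that verbose mode logs
            print(f"Unstable run ({early_stopping.stopping_reason_message}); retrying with lr={lr / 10}")
            lr /= 10
        raise RuntimeError("Training diverged or produced non-finite metrics on every attempt")
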
diff --git a/tests/tests_pytorch/callbacks/test_early_stopping.py b/tests/tests_pytorch/callbacks/test_early_stopping.py
index 9a87b3daaad6e..ff65e08c8c01a 100644
--- a/tests/tests_pytorch/callbacks/test_early_stopping.py
+++ b/tests/tests_pytorch/callbacks/test_early_stopping.py
@@ -11,6 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import copy
 import logging
 import math
 import os
@@ -25,6 +26,7 @@
 
 from lightning.pytorch import Trainer, seed_everything
 from lightning.pytorch.callbacks import EarlyStopping, ModelCheckpoint
+from lightning.pytorch.callbacks.early_stopping import EarlyStoppingReason
 from lightning.pytorch.demos.boring_classes import BoringModel
 from lightning.pytorch.utilities.exceptions import MisconfigurationException
 from tests_pytorch.helpers.datamodules import ClassifDataModule
@@ -505,3 +507,190 @@ def test_early_stopping_log_info(log_rank_zero_only, world_size, global_rank, ex
         log_mock.assert_called_once_with(expected_log)
     else:
         log_mock.assert_not_called()
+
+
+class ModelWithHighLoss(BoringModel):
+    def on_validation_epoch_end(self):
+        self.log("val_loss", 10.0)
+
+
+class ModelWithDecreasingLoss(BoringModel):
+    def __init__(self):
+        super().__init__()
+        self.epoch_losses = [5.0, 3.0, 1.0, 0.5]
+
+    def on_validation_epoch_end(self):
+        loss = self.epoch_losses[self.current_epoch] if self.current_epoch < len(self.epoch_losses) else 0.1
+        self.log("val_loss", loss)
+
+
+class ModelWithIncreasingLoss(BoringModel):
+    def __init__(self):
+        super().__init__()
+        self.epoch_losses = [1.0, 2.0, 5.0, 10.0]
+
+    def on_validation_epoch_end(self):
+        loss = self.epoch_losses[self.current_epoch] if self.current_epoch < len(self.epoch_losses) else 15.0
+        self.log("val_loss", loss)
+
+
+class ModelWithNaNLoss(BoringModel):
+    def __init__(self):
+        super().__init__()
+        self.epoch_losses = [1.0, 0.5, float("nan")]
+
+    def on_validation_epoch_end(self):
+        loss = self.epoch_losses[self.current_epoch] if self.current_epoch < len(self.epoch_losses) else float("nan")
+        self.log("val_loss", loss)
+
+
+class ModelWithImprovingLoss(BoringModel):
+    def __init__(self):
+        super().__init__()
+        self.epoch_losses = [5.0, 4.0, 3.0, 2.0, 1.0]
+
+    def on_validation_epoch_end(self):
+        loss = self.epoch_losses[self.current_epoch] if self.current_epoch < len(self.epoch_losses) else 0.1
+        self.log("val_loss", loss)
+
+
+@pytest.mark.parametrize(
+    (
+        "model_cls",
+        "early_stopping_kwargs",
+        "trainer_kwargs",
+        "expected_reason",
+        "reason_message_substr",
+        "should_stop",
+        "state_dict_override",
+    ),
+    [
+        # Patience exhausted
+        (
+            ModelWithHighLoss,
+            {"monitor": "val_loss", "patience": 2, "verbose": True},
+            {"max_epochs": 10, "enable_progress_bar": False},
+            EarlyStoppingReason.PATIENCE_EXHAUSTED,
+            "did not improve",
+            True,
+            None,
+        ),
+        # Stopping threshold
+        (
+            ModelWithDecreasingLoss,
+            {"monitor": "val_loss", "stopping_threshold": 0.6, "mode": "min", "verbose": True},
+            {"max_epochs": 10, "enable_progress_bar": False},
+            EarlyStoppingReason.STOPPING_THRESHOLD,
+            "Stopping threshold reached",
+            True,
+            None,
+        ),
+        # Divergence threshold
+        (
+            ModelWithIncreasingLoss,
+            {"monitor": "val_loss", "divergence_threshold": 8.0, "mode": "min", "verbose": True},
+            {"max_epochs": 10, "enable_progress_bar": False},
+            EarlyStoppingReason.DIVERGENCE_THRESHOLD,
+            "Divergence threshold reached",
+            True,
+            None,
+        ),
+        # Non-finite metric
+        (
+            ModelWithNaNLoss,
{"monitor": "val_loss", "check_finite": True, "verbose": True}, + {"max_epochs": 10, "enable_progress_bar": False}, + EarlyStoppingReason.NON_FINITE_METRIC, + "is not finite", + True, + None, + ), + # Not stopped (normal completion) + ( + ModelWithImprovingLoss, + {"monitor": "val_loss", "patience": 3, "verbose": True}, + {"max_epochs": 3, "enable_progress_bar": False}, + EarlyStoppingReason.NOT_STOPPED, + None, + False, + None, + ), + # State persistence + ( + None, + {"monitor": "val_loss", "patience": 3}, + {}, + EarlyStoppingReason.PATIENCE_EXHAUSTED, + "Test message", + None, + {"stopping_reason": EarlyStoppingReason.PATIENCE_EXHAUSTED, "stopping_reason_message": "Test message"}, + ), + # Backward compatibility (old state dict) + ( + None, + {"monitor": "val_loss", "patience": 3}, + {}, + EarlyStoppingReason.NOT_STOPPED, + None, + None, + { + "wait_count": 2, + "stopped_epoch": 5, + "best_score": torch.tensor(0.5), + "patience": 3, + }, + ), + ], +) +def test_early_stopping_reasons( + tmp_path, + model_cls, + early_stopping_kwargs, + trainer_kwargs, + expected_reason, + reason_message_substr, + should_stop, + state_dict_override, +): + """Test all early stopping reasons in a single parametrized test.""" + if state_dict_override is not None: + early_stopping = EarlyStopping(**early_stopping_kwargs) + if "stopping_reason" in state_dict_override: + # State persistence test + early_stopping.stopping_reason = state_dict_override["stopping_reason"] + early_stopping.stopping_reason_message = state_dict_override["stopping_reason_message"] + state_dict = early_stopping.state_dict() + new_early_stopping = EarlyStopping(**early_stopping_kwargs) + new_early_stopping.load_state_dict(state_dict) + assert new_early_stopping.stopping_reason == expected_reason + assert new_early_stopping.stopping_reason_message == reason_message_substr + else: + # Backward compatibility test + early_stopping.load_state_dict(copy.deepcopy(state_dict_override)) + assert early_stopping.stopping_reason == expected_reason + assert early_stopping.stopping_reason_message is None + assert early_stopping.wait_count == state_dict_override["wait_count"] + assert early_stopping.stopped_epoch == state_dict_override["stopped_epoch"] + return + + # All other tests + model = model_cls() + early_stopping = EarlyStopping(**early_stopping_kwargs) + trainer = Trainer( + default_root_dir=tmp_path, + callbacks=[early_stopping], + **trainer_kwargs, + ) + trainer.fit(model) + + assert early_stopping.stopping_reason == expected_reason + if reason_message_substr is not None: + assert early_stopping.stopping_reason_message is not None + assert reason_message_substr in early_stopping.stopping_reason_message + else: + assert early_stopping.stopping_reason_message is None + if should_stop is not None: + if should_stop: + assert early_stopping.stopped_epoch > 0 + else: + assert early_stopping.stopped_epoch == 0