.. testsetup:: *

-    from lightning.pytorch.callbacks.early_stopping import EarlyStopping
+    from lightning.pytorch.callbacks.early_stopping import EarlyStopping, EarlyStoppingReason
+    from lightning.pytorch import Trainer, LightningModule

.. _early_stopping:

@@ -71,6 +72,37 @@ Additional parameters that stop training at extreme points:
- ``check_on_train_epoch_end``: When turned on, it checks the metric at the end of a training epoch. Use this only when you are monitoring a metric logged within
  training-specific hooks at the epoch level.

+After training completes, you can programmatically check why early stopping occurred using the ``stopping_reason``
+attribute, which holds an ``EarlyStoppingReason`` enum value.
+
+.. code-block:: python
+
+    from lightning.pytorch.callbacks import EarlyStopping
+    from lightning.pytorch.callbacks.early_stopping import EarlyStoppingReason
+
+    early_stopping = EarlyStopping(monitor="val_loss", patience=3)
+    trainer = Trainer(callbacks=[early_stopping])
+    trainer.fit(model)
+
+    # Check why training stopped
+    if early_stopping.stopping_reason == EarlyStoppingReason.PATIENCE_EXHAUSTED:
+        print("Training stopped due to patience exhaustion")
+    elif early_stopping.stopping_reason == EarlyStoppingReason.STOPPING_THRESHOLD:
+        print("Training stopped due to reaching the stopping threshold")
+    elif early_stopping.stopping_reason == EarlyStoppingReason.NOT_STOPPED:
+        print("Training completed normally without early stopping")
+
+    # Access the human-readable message
+    if early_stopping.stopping_reason_message:
+        print(f"Details: {early_stopping.stopping_reason_message}")
+
+The available stopping reasons are:
+
+- ``NOT_STOPPED``: Training completed normally without early stopping
+- ``STOPPING_THRESHOLD``: Training stopped because the monitored metric reached the stopping threshold
+- ``DIVERGENCE_THRESHOLD``: Training stopped because the monitored metric exceeded the divergence threshold
+- ``PATIENCE_EXHAUSTED``: Training stopped because the monitored metric did not improve for the configured patience
+- ``NON_FINITE_METRIC``: Training stopped because the monitored metric became NaN or infinite

In case you need early stopping in a different part of training, subclass :class:`~lightning.pytorch.callbacks.early_stopping.EarlyStopping`
and change where it is called:
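
A minimal sketch of such a subclass, assuming you want the check to run once at the end of the training loop instead
of after each validation loop (``MyEarlyStopping`` is an illustrative name, and ``_run_early_stopping_check`` is the
internal helper the base class uses to perform the check):

.. code-block:: python

    from lightning.pytorch.callbacks.early_stopping import EarlyStopping


    class MyEarlyStopping(EarlyStopping):
        def on_validation_end(self, trainer, pl_module):
            # skip the check after the validation loop
            pass

        def on_train_end(self, trainer, pl_module):
            # instead, run it once at the end of the training loop
            self._run_early_stopping_check(trainer)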