@@ -1,6 +1,7 @@
 .. testsetup:: *

-    from lightning.pytorch.callbacks.early_stopping import EarlyStopping
+    from lightning.pytorch.callbacks.early_stopping import EarlyStopping, EarlyStoppingReason
+    from lightning.pytorch import Trainer, LightningModule

 .. _early_stopping:

@@ -71,8 +72,37 @@ Additional parameters that stop training at extreme points:
 - ``check_on_train_epoch_end``: When turned on, it checks the metric at the end of a training epoch. Use this only when you are monitoring any metric logged within
   training-specific hooks on epoch-level.

+**Accessing Stopping Reason**
+
+After training completes, you can programmatically check why early stopping occurred using the ``stopping_reason`` attribute, which returns an ``EarlyStoppingReason`` enum value.
+
+.. testcode::
+
+    from lightning.pytorch.callbacks import EarlyStopping, EarlyStoppingReason
+
+    early_stopping = EarlyStopping(monitor="val_loss", patience=3)
+    trainer = Trainer(callbacks=[early_stopping])
+    trainer.fit(model)

-In case you need early stopping in a different part of training, subclass :class:`~lightning.pytorch.callbacks.early_stopping.EarlyStopping`
+    # Check why training stopped
+    if early_stopping.stopping_reason == EarlyStoppingReason.PATIENCE_EXHAUSTED:
+        print("Training stopped due to patience exhaustion")
+    elif early_stopping.stopping_reason == EarlyStoppingReason.STOPPING_THRESHOLD:
+        print("Training stopped due to reaching stopping threshold")
+    elif early_stopping.stopping_reason == EarlyStoppingReason.NOT_STOPPED:
+        print("Training completed normally without early stopping")
+
+    # Access the human-readable message
+    if early_stopping.stopping_reason_message:
+        print(f"Details: {early_stopping.stopping_reason_message}")
+
+The available stopping reasons are:
+
+- ``NOT_STOPPED``: Training completed normally without early stopping
+- ``STOPPING_THRESHOLD``: Training stopped because the monitored metric reached the stopping threshold
+- ``DIVERGENCE_THRESHOLD``: Training stopped because the monitored metric exceeded the divergence threshold
+- ``PATIENCE_EXHAUSTED``: Training stopped because the metric didn't improve for the specified patience
+- ``NON_FINITE_METRIC``: Training stopped because the monitored metric became NaN or infinite
+
+In case you need early stopping in a different part of training, subclass :class:`~lightning.pytorch.callbacks.early_stopping.EarlyStopping`
 and change where it is called:

 .. testcode::
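The dispatch logic in the new documentation example can be exercised standalone. Below is a minimal sketch in which a stdlib ``Enum`` stands in for Lightning's ``EarlyStoppingReason`` (member names are taken from the list in the diff; the integer values and the ``describe`` helper are illustrative, not part of the Lightning API), so the branching pattern runs without Lightning installed:

```python
from enum import Enum


class EarlyStoppingReason(Enum):
    # Stand-in for lightning.pytorch.callbacks.early_stopping.EarlyStoppingReason;
    # member names mirror the documented reasons, values are arbitrary.
    NOT_STOPPED = 0
    STOPPING_THRESHOLD = 1
    DIVERGENCE_THRESHOLD = 2
    PATIENCE_EXHAUSTED = 3
    NON_FINITE_METRIC = 4


def describe(reason: EarlyStoppingReason) -> str:
    """Map a stopping reason to a human-readable summary (hypothetical helper)."""
    messages = {
        EarlyStoppingReason.NOT_STOPPED: "Training completed normally without early stopping",
        EarlyStoppingReason.STOPPING_THRESHOLD: "Training stopped due to reaching stopping threshold",
        EarlyStoppingReason.DIVERGENCE_THRESHOLD: "Training stopped because the metric exceeded the divergence threshold",
        EarlyStoppingReason.PATIENCE_EXHAUSTED: "Training stopped due to patience exhaustion",
        EarlyStoppingReason.NON_FINITE_METRIC: "Training stopped because the metric became NaN or infinite",
    }
    return messages[reason]


# Same shape as the documented check, but against the stand-in enum.
reason = EarlyStoppingReason.PATIENCE_EXHAUSTED
if reason == EarlyStoppingReason.PATIENCE_EXHAUSTED:
    print(describe(reason))
```

In real code the enum comparison works identically against ``early_stopping.stopping_reason``; only the source of the value differs.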