"""
import logging
+from enum import Enum
from typing import Any, Callable, Optional

import torch

log = logging.getLogger(__name__)


+class EarlyStoppingReason(Enum):
+    """Enum for early stopping reasons."""
+
+    NOT_STOPPED = 0
+    STOPPING_THRESHOLD = 1
+    DIVERGENCE_THRESHOLD = 2
+    PATIENCE_EXHAUSTED = 3
+    NON_FINITE_METRIC = 4
+
+
class EarlyStopping(Callback):
    r"""Monitor a metric and stop training when it stops improving.
@@ -65,6 +76,11 @@ class EarlyStopping(Callback):
            If this is ``False``, then the check runs at the end of the validation.
        log_rank_zero_only: When set ``True``, logs the status of the early stopping callback only for rank 0 process.

+    Attributes:
+        stopped_epoch: The epoch at which training was stopped. 0 if training was not stopped.
+        stopping_reason: An ``EarlyStoppingReason`` enum indicating why training was stopped.
+        stopping_reason_message: A human-readable message explaining why training was stopped.
+
    Raises:
        MisconfigurationException:
            If ``mode`` is none of ``"min"`` or ``"max"``.
@@ -74,9 +90,12 @@ class EarlyStopping(Callback):
    Example::

        >>> from lightning.pytorch import Trainer
-        >>> from lightning.pytorch.callbacks import EarlyStopping
+        >>> from lightning.pytorch.callbacks import EarlyStopping, EarlyStoppingReason
        >>> early_stopping = EarlyStopping('val_loss')
        >>> trainer = Trainer(callbacks=[early_stopping])
+        >>> # After training...
+        >>> if early_stopping.stopping_reason == EarlyStoppingReason.PATIENCE_EXHAUSTED:
+        ...     print("Training stopped due to patience exhaustion")

    .. tip:: Saving and restoring multiple early stopping callbacks at the same time is supported under variation in the
        following arguments:
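A rough sketch of how the new attributes might be consumed after a full run (not part of the diff): `model`, the Trainer arguments, and the branching logic are illustrative placeholders, and it assumes `EarlyStoppingReason` is exported from `lightning.pytorch.callbacks` as the updated docstring example suggests.

    from lightning.pytorch import Trainer
    from lightning.pytorch.callbacks import EarlyStopping, EarlyStoppingReason

    early_stopping = EarlyStopping(monitor="val_loss", patience=5)
    trainer = Trainer(callbacks=[early_stopping], max_epochs=100)
    trainer.fit(model)  # `model` is a LightningModule defined elsewhere

    if early_stopping.stopping_reason is EarlyStoppingReason.NON_FINITE_METRIC:
        print("val_loss became NaN/Inf:", early_stopping.stopping_reason_message)
    elif early_stopping.stopping_reason is EarlyStoppingReason.PATIENCE_EXHAUSTED:
        print(f"No improvement for {early_stopping.patience} checks; stopped at epoch {early_stopping.stopped_epoch}")
    elif early_stopping.stopping_reason is EarlyStoppingReason.NOT_STOPPED:
        print("Training ran to completion without early stopping")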
@@ -117,6 +136,8 @@ def __init__(
        self.divergence_threshold = divergence_threshold
        self.wait_count = 0
        self.stopped_epoch = 0
+        self.stopping_reason = EarlyStoppingReason.NOT_STOPPED
+        self.stopping_reason_message = None
        self._check_on_train_epoch_end = check_on_train_epoch_end
        self.log_rank_zero_only = log_rank_zero_only

@@ -169,6 +190,8 @@ def state_dict(self) -> dict[str, Any]:
            "stopped_epoch": self.stopped_epoch,
            "best_score": self.best_score,
            "patience": self.patience,
+            "stopping_reason": self.stopping_reason,
+            "stopping_reason_message": self.stopping_reason_message,
        }

    @override
@@ -177,6 +200,9 @@ def load_state_dict(self, state_dict: dict[str, Any]) -> None:
        self.stopped_epoch = state_dict["stopped_epoch"]
        self.best_score = state_dict["best_score"]
        self.patience = state_dict["patience"]
+        # For backward compatibility, set defaults if not present
+        self.stopping_reason = state_dict.get("stopping_reason", EarlyStoppingReason.NOT_STOPPED)
+        self.stopping_reason_message = state_dict.get("stopping_reason_message")

    def _should_skip_check(self, trainer: "pl.Trainer") -> bool:
        from lightning.pytorch.trainer.states import TrainerFn
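A minimal sketch of the backward-compatibility path (illustrative values, not part of the diff): a state dict produced before this change, which lacks the new keys, simply falls back to the defaults.

    from lightning.pytorch.callbacks import EarlyStopping, EarlyStoppingReason

    cb = EarlyStopping(monitor="val_loss", patience=3)
    # An old-format checkpoint entry without the new keys
    old_state = {"wait_count": 1, "stopped_epoch": 0, "best_score": 0.42, "patience": 3}
    cb.load_state_dict(old_state)
    assert cb.stopping_reason is EarlyStoppingReason.NOT_STOPPED
    assert cb.stopping_reason_message is None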
@@ -212,6 +238,8 @@ def _run_early_stopping_check(self, trainer: "pl.Trainer") -> None:
        trainer.should_stop = trainer.should_stop or should_stop
        if should_stop:
            self.stopped_epoch = trainer.current_epoch
+            # Store the stopping reason message
+            self.stopping_reason_message = reason
        if reason and self.verbose:
            self._log_info(trainer, reason, self.log_rank_zero_only)

@@ -220,19 +248,22 @@ def _evaluate_stopping_criteria(self, current: Tensor) -> tuple[bool, Optional[str]]:
        reason = None
        if self.check_finite and not torch.isfinite(current):
            should_stop = True
+            self.stopping_reason = EarlyStoppingReason.NON_FINITE_METRIC
            reason = (
                f"Monitored metric {self.monitor} = {current} is not finite."
                f" Previous best value was {self.best_score:.3f}. Signaling Trainer to stop."
            )
        elif self.stopping_threshold is not None and self.monitor_op(current, self.stopping_threshold):
            should_stop = True
+            self.stopping_reason = EarlyStoppingReason.STOPPING_THRESHOLD
            reason = (
                "Stopping threshold reached:"
                f" {self.monitor} = {current} {self.order_dict[self.mode]} {self.stopping_threshold}."
                " Signaling Trainer to stop."
            )
        elif self.divergence_threshold is not None and self.monitor_op(-current, -self.divergence_threshold):
            should_stop = True
+            self.stopping_reason = EarlyStoppingReason.DIVERGENCE_THRESHOLD
            reason = (
                "Divergence threshold reached:"
                f" {self.monitor} = {current} {self.order_dict[self.mode]} {self.divergence_threshold}."
@@ -247,6 +278,7 @@ def _evaluate_stopping_criteria(self, current: Tensor) -> tuple[bool, Optional[str]]:
            self.wait_count += 1
            if self.wait_count >= self.patience:
                should_stop = True
+                self.stopping_reason = EarlyStoppingReason.PATIENCE_EXHAUSTED
                reason = (
                    f"Monitored metric {self.monitor} did not improve in the last {self.wait_count} records."
                    f" Best score: {self.best_score:.3f}. Signaling Trainer to stop."