
Commit d5431d4

Merge branch 'master' into weights-only-compatibility
2 parents 0685799 + 6989e15

File tree

15 files changed: +666 −27 lines changed

docs/source-pytorch/common/early_stopping.rst

Lines changed: 33 additions & 1 deletion
@@ -1,6 +1,7 @@
 .. testsetup:: *
 
-    from lightning.pytorch.callbacks.early_stopping import EarlyStopping
+    from lightning.pytorch.callbacks.early_stopping import EarlyStopping, EarlyStoppingReason
+    from lightning.pytorch import Trainer, LightningModule
 
 
 .. _early_stopping:
@@ -71,6 +72,37 @@ Additional parameters that stop training at extreme points:
 - ``check_on_train_epoch_end``: When turned on, it checks the metric at the end of a training epoch. Use this only when you are monitoring any metric logged within
   training-specific hooks on epoch-level.
 
+After training completes, you can programmatically check why early stopping occurred using the ``stopping_reason``
+attribute, which returns an ``EarlyStoppingReason`` enum value.
+
+.. code-block:: python
+
+    from lightning.pytorch.callbacks import EarlyStopping
+    from lightning.pytorch.callbacks.early_stopping import EarlyStoppingReason
+
+    early_stopping = EarlyStopping(monitor="val_loss", patience=3)
+    trainer = Trainer(callbacks=[early_stopping])
+    trainer.fit(model)
+
+    # Check why training stopped
+    if early_stopping.stopping_reason == EarlyStoppingReason.PATIENCE_EXHAUSTED:
+        print("Training stopped due to patience exhaustion")
+    elif early_stopping.stopping_reason == EarlyStoppingReason.STOPPING_THRESHOLD:
+        print("Training stopped due to reaching the stopping threshold")
+    elif early_stopping.stopping_reason == EarlyStoppingReason.NOT_STOPPED:
+        print("Training completed normally without early stopping")
+
+    # Access the human-readable message
+    if early_stopping.stopping_reason_message:
+        print(f"Details: {early_stopping.stopping_reason_message}")
+
+The available stopping reasons are:
+
+- ``NOT_STOPPED``: Training completed normally without early stopping
+- ``STOPPING_THRESHOLD``: Training stopped because the monitored metric reached the stopping threshold
+- ``DIVERGENCE_THRESHOLD``: Training stopped because the monitored metric exceeded the divergence threshold
+- ``PATIENCE_EXHAUSTED``: Training stopped because the metric didn't improve for the specified patience
+- ``NON_FINITE_METRIC``: Training stopped because the monitored metric became NaN or infinite
 
 In case you need early stopping in a different part of training, subclass :class:`~lightning.pytorch.callbacks.early_stopping.EarlyStopping`
 and change where it is called:
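
The enum makes run outcomes machine-readable. As a quick illustration of how downstream code might consume it, here is a minimal sketch in which `summarize_early_stopping` is a hypothetical helper (not part of Lightning) that collapses the reason into a coarse status string, e.g. for tagging runs in a hyperparameter sweep:

from lightning.pytorch.callbacks.early_stopping import EarlyStopping, EarlyStoppingReason


def summarize_early_stopping(callback: EarlyStopping) -> str:
    """Hypothetical helper: collapse the stopping reason into a coarse run status."""
    reason = callback.stopping_reason
    if reason in (EarlyStoppingReason.NON_FINITE_METRIC, EarlyStoppingReason.DIVERGENCE_THRESHOLD):
        status = "failed"  # the metric blew up or went non-finite
    elif reason is EarlyStoppingReason.NOT_STOPPED:
        status = "ran-to-completion"  # early stopping never triggered
    else:
        status = "early-stopped"  # stopping threshold reached or patience exhausted
    detail = callback.stopping_reason_message or "no early-stopping message recorded"
    return f"{status}: {detail}"

Because both attributes are also persisted in the callback's state_dict (see the early_stopping.py diff below), the same check works after restoring from a checkpoint.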

src/lightning/pytorch/CHANGELOG.md

Lines changed: 17 additions & 2 deletions
@@ -19,12 +19,21 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Added time-based validation support through `val_check_interval` ([#21071](https://github.com/Lightning-AI/pytorch-lightning/pull/21071))
 
 
+- Added attributes to access the stopping reason in the `EarlyStopping` callback ([#21188](https://github.com/Lightning-AI/pytorch-lightning/pull/21188))
+
+
+- Added support for variable batch size in `ThroughputMonitor` ([#20236](https://github.com/Lightning-AI/pytorch-lightning/pull/20236))
+
+
 ### Changed
 
 - Default to `weights_only=True` for `torch>=2.6` when loading checkpoints. ([#21072](https://github.com/Lightning-AI/pytorch-lightning/pull/21072))
 
 
--
+- Default to `RichProgressBar` and `RichModelSummary` if the `rich` package is available, and fall back to `TQDMProgressBar` and `ModelSummary` otherwise ([#20896](https://github.com/Lightning-AI/pytorch-lightning/pull/20896))
+
+
+- Prevented recursive symlink creation when `save_last='link'` and `save_top_k=-1` ([#21186](https://github.com/Lightning-AI/pytorch-lightning/pull/21186))
 
 
 ### Removed
@@ -34,7 +43,13 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 ### Fixed
 
--
+- Fixed an edge case when `max_trials` is reached in `Tuner.scale_batch_size` ([#21187](https://github.com/Lightning-AI/pytorch-lightning/pull/21187))
+
+
+- Fixed a case where `LightningCLI` could not be initialized with `trainer_default` containing callbacks ([#21192](https://github.com/Lightning-AI/pytorch-lightning/pull/21192))
+
+
+- Fixed missing reset when `ModelPruning` is applied with the lottery ticket hypothesis ([#21191](https://github.com/Lightning-AI/pytorch-lightning/pull/21191))
 
 
 ---

src/lightning/pytorch/callbacks/early_stopping.py

Lines changed: 32 additions & 0 deletions
@@ -20,6 +20,7 @@
 """
 
 import logging
+from enum import Enum
 from typing import Any, Callable, Optional
 
 import torch
@@ -34,6 +35,16 @@
 log = logging.getLogger(__name__)
 
 
+class EarlyStoppingReason(Enum):
+    """Enum for early stopping reasons."""
+
+    NOT_STOPPED = 0
+    STOPPING_THRESHOLD = 1
+    DIVERGENCE_THRESHOLD = 2
+    PATIENCE_EXHAUSTED = 3
+    NON_FINITE_METRIC = 4
+
+
 class EarlyStopping(Callback):
     r"""Monitor a metric and stop training when it stops improving.
 
@@ -65,6 +76,11 @@ class EarlyStopping(Callback):
             If this is ``False``, then the check runs at the end of the validation.
         log_rank_zero_only: When set ``True``, logs the status of the early stopping callback only for rank 0 process.
 
+    Attributes:
+        stopped_epoch: The epoch at which training was stopped. 0 if training was not stopped.
+        stopping_reason: An ``EarlyStoppingReason`` enum indicating why training was stopped.
+        stopping_reason_message: A human-readable message explaining why training was stopped.
+
     Raises:
         MisconfigurationException:
             If ``mode`` is none of ``"min"`` or ``"max"``.
@@ -75,8 +91,12 @@ class EarlyStopping(Callback):
 
         >>> from lightning.pytorch import Trainer
         >>> from lightning.pytorch.callbacks import EarlyStopping
+        >>> from lightning.pytorch.callbacks.early_stopping import EarlyStoppingReason
         >>> early_stopping = EarlyStopping('val_loss')
         >>> trainer = Trainer(callbacks=[early_stopping])
+        >>> # After training...
+        >>> if early_stopping.stopping_reason == EarlyStoppingReason.PATIENCE_EXHAUSTED:
+        ...     print("Training stopped due to patience exhaustion")
 
     .. tip:: Saving and restoring multiple early stopping callbacks at the same time is supported under variation in the
         following arguments:
@@ -117,6 +137,8 @@ def __init__(
         self.divergence_threshold = divergence_threshold
         self.wait_count = 0
         self.stopped_epoch = 0
+        self.stopping_reason = EarlyStoppingReason.NOT_STOPPED
+        self.stopping_reason_message: Optional[str] = None
         self._check_on_train_epoch_end = check_on_train_epoch_end
         self.log_rank_zero_only = log_rank_zero_only
 
@@ -169,6 +191,8 @@ def state_dict(self) -> dict[str, Any]:
             "stopped_epoch": self.stopped_epoch,
             "best_score": self.best_score,
             "patience": self.patience,
+            "stopping_reason": self.stopping_reason.value,
+            "stopping_reason_message": self.stopping_reason_message,
         }
 
     @override
@@ -177,6 +201,9 @@ def load_state_dict(self, state_dict: dict[str, Any]) -> None:
         self.stopped_epoch = state_dict["stopped_epoch"]
        self.best_score = state_dict["best_score"]
         self.patience = state_dict["patience"]
+        stopping_reason_value = state_dict.get("stopping_reason", EarlyStoppingReason.NOT_STOPPED.value)
+        self.stopping_reason = EarlyStoppingReason(stopping_reason_value)
+        self.stopping_reason_message = state_dict.get("stopping_reason_message")
 
     def _should_skip_check(self, trainer: "pl.Trainer") -> bool:
         from lightning.pytorch.trainer.states import TrainerFn
@@ -212,6 +239,7 @@ def _run_early_stopping_check(self, trainer: "pl.Trainer") -> None:
         trainer.should_stop = trainer.should_stop or should_stop
         if should_stop:
             self.stopped_epoch = trainer.current_epoch
+            self.stopping_reason_message = reason
         if reason and self.verbose:
             self._log_info(trainer, reason, self.log_rank_zero_only)
 
@@ -220,19 +248,22 @@ def _evaluate_stopping_criteria(self, current: Tensor) -> tuple[bool, Optional[str]]:
         reason = None
         if self.check_finite and not torch.isfinite(current):
             should_stop = True
+            self.stopping_reason = EarlyStoppingReason.NON_FINITE_METRIC
             reason = (
                 f"Monitored metric {self.monitor} = {current} is not finite."
                 f" Previous best value was {self.best_score:.3f}. Signaling Trainer to stop."
             )
         elif self.stopping_threshold is not None and self.monitor_op(current, self.stopping_threshold):
             should_stop = True
+            self.stopping_reason = EarlyStoppingReason.STOPPING_THRESHOLD
             reason = (
                 "Stopping threshold reached:"
                 f" {self.monitor} = {current} {self.order_dict[self.mode]} {self.stopping_threshold}."
                 " Signaling Trainer to stop."
             )
         elif self.divergence_threshold is not None and self.monitor_op(-current, -self.divergence_threshold):
             should_stop = True
+            self.stopping_reason = EarlyStoppingReason.DIVERGENCE_THRESHOLD
             reason = (
                 "Divergence threshold reached:"
                 f" {self.monitor} = {current} {self.order_dict[self.mode]} {self.divergence_threshold}."
@@ -247,6 +278,7 @@ def _evaluate_stopping_criteria(self, current: Tensor) -> tuple[bool, Optional[str]]:
             self.wait_count += 1
             if self.wait_count >= self.patience:
                 should_stop = True
+                self.stopping_reason = EarlyStoppingReason.PATIENCE_EXHAUSTED
                 reason = (
                     f"Monitored metric {self.monitor} did not improve in the last {self.wait_count} records."
                     f" Best score: {self.best_score:.3f}. Signaling Trainer to stop."

src/lightning/pytorch/callbacks/model_checkpoint.py

Lines changed: 1 addition & 1 deletion
@@ -484,7 +484,7 @@ def _save_checkpoint(self, trainer: "pl.Trainer", filepath: str) -> None:
 
     @staticmethod
     def _link_checkpoint(trainer: "pl.Trainer", filepath: str, linkpath: str) -> None:
-        if trainer.is_global_zero:
+        if trainer.is_global_zero and os.path.abspath(filepath) != os.path.abspath(linkpath):
             if os.path.islink(linkpath) or os.path.isfile(linkpath):
                 os.remove(linkpath)
             elif os.path.isdir(linkpath):
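
The extra `os.path.abspath` comparison addresses the recursion reported with `save_last='link'` and `save_top_k=-1`: when the "last" link would point at the checkpoint path itself, the method now does nothing instead of removing the file and risking a self-referential symlink. A standalone, simplified sketch of the guarded behavior using plain os calls (not Lightning's internals, and omitting the directory branch):

import os
import tempfile

def link_last(filepath: str, linkpath: str) -> None:
    """Sketch of the guarded linking logic above (standalone, simplified)."""
    if os.path.abspath(filepath) == os.path.abspath(linkpath):
        return  # the new guard: nothing to do when source and destination are the same path
    if os.path.islink(linkpath) or os.path.isfile(linkpath):
        os.remove(linkpath)
    os.symlink(filepath, linkpath)

with tempfile.TemporaryDirectory() as tmp:
    ckpt = os.path.join(tmp, "last.ckpt")
    open(ckpt, "w").close()
    link_last(ckpt, ckpt)        # no-op instead of deleting the file and linking it to itself
    assert os.path.isfile(ckpt)  # the checkpoint is still there
    link_last(ckpt, os.path.join(tmp, "last-link.ckpt"))  # the normal case still creates a link
    assert os.path.islink(os.path.join(tmp, "last-link.ckpt"))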

src/lightning/pytorch/callbacks/pruning.py

Lines changed: 2 additions & 1 deletion
@@ -277,7 +277,8 @@ def make_pruning_permanent(self, module: nn.Module) -> None:
 
     @staticmethod
     def _copy_param(new: nn.Module, old: nn.Module, name: str) -> None:
-        dst = getattr(new, name)
+        # Check if the parameter has been pruned (has _orig suffix)
+        dst = getattr(new, name + "_orig") if hasattr(new, name + "_orig") else getattr(new, name)
         src = getattr(old, name)
         if dst is None or src is None or not isinstance(dst, Tensor) or not isinstance(src, Tensor):
             return
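
The `_orig` lookup reflects how `torch.nn.utils.prune` rewires a module: after pruning, the original tensor is kept as `<name>_orig`, a `<name>_mask` buffer is added, and `<name>` itself is no longer a plain Parameter. A short demonstration with stock PyTorch:

import torch
from torch import nn
from torch.nn.utils import prune

linear = nn.Linear(4, 2)
prune.l1_unstructured(linear, name="weight", amount=0.5)

params = dict(linear.named_parameters())
print("weight" in params)        # False: `weight` is now a recomputed, masked attribute
print("weight_orig" in params)   # True: the original, unpruned parameter lives here
print(linear.weight_mask.shape)  # the binary mask combined with weight_orig to produce `weight`

# This is presumably why the lottery-ticket reset above must write into `weight_orig` when it
# exists; a value copied into `weight` would be overwritten the next time the pruning hook runs.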

src/lightning/pytorch/callbacks/throughput_monitor.py

Lines changed: 14 additions & 5 deletions
@@ -87,6 +87,8 @@ def __init__(
         self._throughputs: dict[RunningStage, Throughput] = {}
         self._t0s: dict[RunningStage, float] = {}
         self._lengths: dict[RunningStage, int] = {}
+        self._samples: dict[RunningStage, int] = {}
+        self._batches: dict[RunningStage, int] = {}
 
     @override
     def setup(self, trainer: "Trainer", pl_module: "LightningModule", stage: str) -> None:
@@ -106,8 +108,13 @@ def setup(self, trainer: "Trainer", pl_module: "LightningModule", stage: str) -> None:
     def _start(self, trainer: "Trainer") -> None:
         stage = trainer.state.stage
         assert stage is not None
-        self._throughputs[stage].reset()
-        self._lengths[stage] = 0
+
+        if stage not in self._samples:
+            self._throughputs[stage].reset()
+            self._lengths[stage] = 0
+            self._samples[stage] = 0
+            self._batches[stage] = 0
+
         self._t0s[stage] = time.perf_counter()
 
     @torch.inference_mode()  # in case `length_fn` or `batch_size_fn` computes grads
@@ -133,12 +140,14 @@ def _update(self, trainer: "Trainer", pl_module: "LightningModule", batch: Any, iter_num: int) -> None:
             )
             flops_per_batch = None
 
-        batch_size = self.batch_size_fn(batch)
+        self._samples[stage] += self.batch_size_fn(batch)
+        self._batches[stage] += 1
+
         throughput.update(
             time=elapsed,
-            batches=iter_num,
+            batches=self._batches[stage],
             # this assumes that all iterations used the same batch size
-            samples=iter_num * batch_size,
+            samples=self._samples[stage],
             lengths=None if self.length_fn is None else self._lengths[stage],
             flops=flops_per_batch,  # type: ignore[arg-type]
         )
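
The change replaces the old `samples = iter_num * batch_size` estimate, which silently assumed a constant batch size, with running `_samples`/`_batches` counters. A tiny illustration of the difference, outside Lightning:

# Three iterations where the last batch is smaller (e.g. the tail of a dataset).
batch_sizes = [32, 32, 17]

samples_seen = 0
batches_seen = 0
for size in batch_sizes:
    batches_seen += 1
    samples_seen += size  # what `self._samples[stage] += self.batch_size_fn(batch)` accumulates

print(batches_seen, samples_seen)          # 3 81 -> correct totals
print(len(batch_sizes) * batch_sizes[-1])  # 3 * 17 = 51, the old `iter_num * batch_size` estimate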

src/lightning/pytorch/trainer/connectors/callback_connector.py

Lines changed: 1 addition & 1 deletion
@@ -240,7 +240,7 @@ def _reorder_callbacks(callbacks: list[Callback]) -> list[Callback]:
 
 
 def _validate_callbacks_list(callbacks: list[Callback]) -> None:
-    stateful_callbacks = [cb for cb in callbacks if is_overridden("state_dict", instance=cb)]
+    stateful_callbacks = [cb for cb in callbacks if is_overridden("state_dict", instance=cb, parent=Callback)]
     seen_callbacks = set()
     for callback in stateful_callbacks:
         if callback.state_key in seen_callbacks:
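
Passing `parent=Callback` pins the base class that `is_overridden` compares against instead of letting it be inferred from the instance. The underlying idea is a method-identity check; a generic sketch of that idea (not Lightning's actual `is_overridden` implementation):

from lightning.pytorch.callbacks import Callback

class StatelessCallback(Callback):
    pass

class StatefulCallback(Callback):
    def state_dict(self):
        return {"counter": 1}

def overrides_state_dict(cb: Callback) -> bool:
    # Compare the function the instance's class resolves against the one defined on Callback.
    return type(cb).state_dict is not Callback.state_dict

print(overrides_state_dict(StatelessCallback()))  # False -> not treated as stateful
print(overrides_state_dict(StatefulCallback()))   # True  -> must have a unique state_key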

src/lightning/pytorch/tuner/batch_size_scaling.py

Lines changed: 16 additions & 2 deletions
@@ -178,14 +178,22 @@ def _run_power_scaling(
     # this flag is used to determine whether the previously scaled batch size, right before OOM, was a success or not
     # if it was we exit, else we continue downscaling in case we haven't encountered a single optimal batch size
     any_success = False
-    for _ in range(max_trials):
+    last_successful_size = new_size
+    for i in range(max_trials):
         garbage_collection_cuda()
 
         # reset after each try
         _reset_progress(trainer)
 
         try:
             _try_loop_run(trainer, params)
+            last_successful_size = new_size  # Store the current size before doubling
+
+            # Check if this is the last trial before trying to double
+            if i + 1 >= max_trials:
+                new_size = last_successful_size
+                break
+
             new_size, changed = _adjust_batch_size(trainer, batch_arg_name, factor=2.0, desc="succeeded")
 
             if not changed:
@@ -224,6 +232,7 @@ def _run_binary_scaling(
     low = 1
     high = None
     count = 0
+    last_successful_size = new_size
     while True:
         garbage_collection_cuda()
 
@@ -233,9 +242,14 @@
         try:
             # run loop
             _try_loop_run(trainer, params)
+            last_successful_size = new_size  # Store the current size before doubling
             count += 1
-            if count > max_trials:
+
+            # Check if we've reached max_trials before trying to adjust batch size
+            if count >= max_trials:
+                new_size = last_successful_size
                 break
+
             # Double in size
             low = new_size
             if high:
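
With this fix, hitting `max_trials` in either scaling mode returns the last batch size that actually ran rather than the next, untried doubling. A hedged end-to-end sketch of how the tuner is driven; `TinyModel` below is a made-up minimal module, while `Tuner.scale_batch_size` and its `mode`/`max_trials` arguments come from Lightning:

import torch
from torch.utils.data import DataLoader, TensorDataset
from lightning.pytorch import LightningModule, Trainer
from lightning.pytorch.tuner import Tuner


class TinyModel(LightningModule):
    def __init__(self, batch_size: int = 2):
        super().__init__()
        self.batch_size = batch_size  # the attribute the tuner rewrites between trials
        self.layer = torch.nn.Linear(32, 2)

    def training_step(self, batch, batch_idx):
        x, y = batch
        return torch.nn.functional.cross_entropy(self.layer(x), y)

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)

    def train_dataloader(self):
        dataset = TensorDataset(torch.randn(512, 32), torch.randint(0, 2, (512,)))
        return DataLoader(dataset, batch_size=self.batch_size)


trainer = Trainer(max_epochs=1, logger=False, enable_checkpointing=False)
new_size = Tuner(trainer).scale_batch_size(TinyModel(), mode="power", max_trials=5)
print(new_size)  # the last size that actually completed a trial, not a blindly doubled value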
