
Commit 2a4ba44

Merge branch 'master' into fix/fsdp-mixed-precision
2 parents d80ca08 + 5b52e8f commit 2a4ba44

14 files changed: +253 −41 lines

.github/CONTRIBUTING.md

Lines changed: 2 additions & 0 deletions

@@ -212,6 +212,8 @@ We welcome any useful contribution! For your convenience here's a recommended wo
   - [Test README](https://github.com/Lightning-AI/pytorch-lightning/blob/master/tests/README.md)
   - [CI/CD README](https://github.com/Lightning-AI/pytorch-lightning/tree/master/.github/workflows#readme)

+1. Once you have a PR opened (and thereby a PR number), please update the respective changelog for [fabric](https://github.com/Lightning-AI/pytorch-lightning/blob/master/src/lightning/fabric/CHANGELOG.md) or [pytorch](https://github.com/Lightning-AI/pytorch-lightning/blob/master/src/lightning/pytorch/CHANGELOG.md) subpackage depending on where you made your changes.
+
 1. When you feel ready for integrating your work, mark your PR "Ready for review".

   - Your code should be readable and follow the project's design principles.

.github/workflows/ci-tests-fabric.yml

Lines changed: 5 additions & 8 deletions

@@ -40,23 +40,20 @@ jobs:
       matrix:
         os: [macOS-14, ubuntu-22.04, windows-2022]
         config:
-          # only run PyTorch latest
+          # Test unified "lightning" package with PyTorch 2.1-2.5
           - { pkg-name: "lightning", python-version: "3.10", pytorch-version: "2.1" }
           - { pkg-name: "lightning", python-version: "3.11", pytorch-version: "2.2.2" }
           - { pkg-name: "lightning", python-version: "3.11", pytorch-version: "2.3" }
           - { pkg-name: "lightning", python-version: "3.12.7", pytorch-version: "2.4.1" }
           - { pkg-name: "lightning", python-version: "3.12.7", pytorch-version: "2.5.1" }

-          # only run PyTorch latest with Python latest, use Fabric scope to limit dependency issues
+          # Test "fabric" package with PyTorch 2.6-2.9
           - { pkg-name: "fabric", python-version: "3.12.7", pytorch-version: "2.6" }
-
-          # "fabric" installs the standalone package
-          - { pkg-name: "fabric", python-version: "3.10", pytorch-version: "2.7" }
-
-          # adding recently cut Torch 2.7 - FUTURE
+          - { pkg-name: "fabric", python-version: "3.12", pytorch-version: "2.7" }
           - { pkg-name: "fabric", python-version: "3.12", pytorch-version: "2.8" }
+          - { pkg-name: "fabric", python-version: "3.12", pytorch-version: "2.9" }

-          # "oldest" versions tests, only on minimum Python
+          # Test minimum supported versions (oldest)
           - { pkg-name: "fabric", pytorch-version: "2.1", requires: "oldest" }
     timeout-minutes: 25 # because of building grpcio on Mac
     env:

.github/workflows/ci-tests-pytorch.yml

Lines changed: 5 additions & 8 deletions

@@ -44,23 +44,20 @@ jobs:
       matrix:
         os: [macOS-14, ubuntu-22.04, windows-2022]
         config:
-          # only run PyTorch latest
+          # Test unified "lightning" package with PyTorch 2.1-2.5
           - { pkg-name: "lightning", python-version: "3.10", pytorch-version: "2.1" }
           - { pkg-name: "lightning", python-version: "3.11", pytorch-version: "2.2.2" }
           - { pkg-name: "lightning", python-version: "3.11", pytorch-version: "2.3" }
           - { pkg-name: "lightning", python-version: "3.12.7", pytorch-version: "2.4.1" }
           - { pkg-name: "lightning", python-version: "3.12.7", pytorch-version: "2.5.1" }

-          # only run PyTorch latest with Python latest, use PyTorch scope to limit dependency issues
+          # Test "pytorch" package with PyTorch 2.6-2.9
           - { pkg-name: "pytorch", python-version: "3.12.7", pytorch-version: "2.6" }
-
-          # "pytorch" installs the standalone package
-          - { pkg-name: "pytorch", python-version: "3.10", pytorch-version: "2.7" }
-
-          # adding recently cut Torch 2.7 - FUTURE
+          - { pkg-name: "pytorch", python-version: "3.12", pytorch-version: "2.7" }
           - { pkg-name: "pytorch", python-version: "3.12", pytorch-version: "2.8" }
+          - { pkg-name: "pytorch", python-version: "3.12", pytorch-version: "2.9" }

-          # "oldest" versions tests, only on minimum Python
+          # Test minimum supported versions (oldest)
           - { pkg-name: "pytorch", pytorch-version: "2.1", requires: "oldest" }
     timeout-minutes: 50
     env:

src/lightning/fabric/CHANGELOG.md

Lines changed: 5 additions & 1 deletion

@@ -25,7 +25,11 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

 ### Fixed

--
+- Fixed `EADDRINUSE` errors in distributed tests with port manager and retry logic ([#21309](https://github.com/Lightning-AI/pytorch-lightning/pull/21309))
+
+
+- Learning rate scheduler is stepped at the end of epoch when `on_train_batch_start` returns -1 ([#21296](https://github.com/Lightning-AI/pytorch-lightning/issues/21296)).
+


 ---
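
For context on the first entry: a "port manager with retry logic" generally means reserving a free TCP port for the test's process group and retrying the launch when the bind races with another process. The sketch below is a hypothetical illustration of that idea only, not the code from #21309; `find_free_port` and `run_with_port_retry` are invented names.

```python
# Hypothetical sketch (not the Lightning implementation): reserve a free TCP port
# for a distributed test and retry the launch if the port is already in use.
import errno
import socket


def find_free_port() -> int:
    """Ask the OS for an unused TCP port by binding to port 0."""
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        s.bind(("127.0.0.1", 0))
        return s.getsockname()[1]


def run_with_port_retry(launch, max_retries: int = 3) -> None:
    """Call ``launch(port)`` with a fresh port, retrying on EADDRINUSE."""
    for attempt in range(max_retries):
        port = find_free_port()
        try:
            launch(port)
            return
        except OSError as exc:
            # Another process may have grabbed the port between reservation and use.
            if exc.errno != errno.EADDRINUSE or attempt == max_retries - 1:
                raise
```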

src/lightning/pytorch/CHANGELOG.md

Lines changed: 3 additions & 0 deletions

@@ -25,6 +25,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Added support for variable batch size in `ThroughputMonitor` ([#20236](https://github.com/Lightning-AI/pytorch-lightning/pull/20236))


+- Added `EMAWeightAveraging` callback that wraps Lightning's `WeightAveraging` class ([#21260](https://github.com/Lightning-AI/pytorch-lightning/pull/21260))
+
+
 ### Changed

 - Expose `weights_only` argument for `Trainer.{fit,validate,test,predict}` and let `torch` handle default value ([#21072](https://github.com/Lightning-AI/pytorch-lightning/pull/21072))

src/lightning/pytorch/callbacks/__init__.py

Lines changed: 2 additions & 1 deletion

@@ -32,7 +32,7 @@
 from lightning.pytorch.callbacks.stochastic_weight_avg import StochasticWeightAveraging
 from lightning.pytorch.callbacks.throughput_monitor import ThroughputMonitor
 from lightning.pytorch.callbacks.timer import Timer
-from lightning.pytorch.callbacks.weight_averaging import WeightAveraging
+from lightning.pytorch.callbacks.weight_averaging import EMAWeightAveraging, WeightAveraging

 __all__ = [
     "BackboneFinetuning",
@@ -59,5 +59,6 @@
     "ThroughputMonitor",
     "Timer",
     "TQDMProgressBar",
+    "EMAWeightAveraging",
     "WeightAveraging",
 ]
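
With this re-export in place, the new callback is importable directly from the public `lightning.pytorch.callbacks` namespace; a minimal sanity check, assuming a build that already contains this commit, looks like:

```python
# Assumes a lightning build that includes the EMAWeightAveraging export (#21260).
from lightning.pytorch.callbacks import EMAWeightAveraging, WeightAveraging

# The new callback subclasses the existing WeightAveraging callback (see next diff).
assert issubclass(EMAWeightAveraging, WeightAveraging)
```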

src/lightning/pytorch/callbacks/weight_averaging.py

Lines changed: 53 additions & 1 deletion

@@ -21,7 +21,7 @@
 from typing import Any, Optional, Union

 import torch
-from torch.optim.swa_utils import AveragedModel
+from torch.optim.swa_utils import AveragedModel, get_ema_avg_fn
 from typing_extensions import override

 import lightning.pytorch as pl
@@ -361,3 +361,55 @@ def _copy_average_to_current(self, pl_module: "pl.LightningModule") -> None:
         current_params = itertools.chain(pl_module.parameters(), pl_module.buffers())
         for average_param, current_param in zip(average_params, current_params):
             current_param.data.copy_(average_param.data)
+
+
+class EMAWeightAveraging(WeightAveraging):
+    """Exponential Moving Average (EMA) Weight Averaging callback."""
+
+    def __init__(
+        self,
+        device: Optional[Union[torch.device, str, int]] = None,
+        use_buffers: bool = True,
+        decay: float = 0.999,
+        update_every_n_steps: int = 1,
+        update_starting_at_step: Optional[int] = None,
+        update_starting_at_epoch: Optional[int] = None,
+        **kwargs: Any,
+    ):
+        super().__init__(
+            device=device,
+            use_buffers=use_buffers,
+            **kwargs,
+            avg_fn=get_ema_avg_fn(decay=decay),
+        )
+
+        self.update_every_n_steps = update_every_n_steps
+        self.update_starting_at_step = update_starting_at_step
+        self.update_starting_at_epoch = update_starting_at_epoch
+
+    def should_update(self, step_idx: Optional[int] = None, epoch_idx: Optional[int] = None) -> bool:
+        """Decide when to update the model weights.
+
+        Args:
+            step_idx: The current step index.
+            epoch_idx: The current epoch index.
+        Returns:
+            bool: True if the model weights should be updated, False otherwise.
+
+        """
+        if step_idx is not None:
+            # Check step-based conditions only if we have a valid step_idx
+            meets_step_requirement = self.update_starting_at_step is None or step_idx >= self.update_starting_at_step
+            meets_step_frequency = self.update_every_n_steps > 0 and step_idx % self.update_every_n_steps == 0
+            if meets_step_requirement and meets_step_frequency:
+                return True
+
+        if epoch_idx is not None:
+            # Check epoch-based condition only if we specify one
+            meets_epoch_requirement = (
+                self.update_starting_at_epoch is not None and epoch_idx >= self.update_starting_at_epoch
+            )
+            if meets_epoch_requirement:
+                return True
+
+        return False
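
A minimal usage sketch of the new callback follows; the toy model, dataset, and trainer settings are illustrative placeholders and are not part of this commit.

```python
# Illustrative usage of the new EMAWeightAveraging callback; TinyModel and the
# trainer arguments below are placeholders, not taken from the diff.
import torch
from torch.utils.data import DataLoader, TensorDataset

import lightning.pytorch as pl
from lightning.pytorch.callbacks import EMAWeightAveraging


class TinyModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(8, 1)

    def training_step(self, batch, batch_idx):
        x, y = batch
        return torch.nn.functional.mse_loss(self.layer(x), y)

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)


data = DataLoader(TensorDataset(torch.randn(64, 8), torch.randn(64, 1)), batch_size=16)
# EMA of the weights with decay 0.999, updated after every optimizer step.
ema = EMAWeightAveraging(decay=0.999, update_every_n_steps=1)
trainer = pl.Trainer(max_epochs=2, callbacks=[ema], logger=False, enable_checkpointing=False)
trainer.fit(TinyModel(), data)
```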

src/lightning/pytorch/core/hooks.py

Lines changed: 1 addition & 0 deletions

@@ -69,6 +69,7 @@ def on_train_batch_start(self, batch: Any, batch_idx: int) -> Optional[int]:
         """Called in the training loop before anything happens for that batch.

         If you return -1 here, you will skip training for the rest of the current epoch.
+        Learning rate scheduler will still be stepped at the end of epoch.

         Args:
             batch: The batched data as it is returned by the training DataLoader.
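
To illustrate the documented behavior: returning -1 from `on_train_batch_start` skips the rest of the current epoch, and with this fix an epoch-interval scheduler such as `StepLR` is still stepped at the end of that epoch. The module below is a hedged sketch, not code from the diff.

```python
# Sketch of a LightningModule that ends each epoch early by returning -1 from
# on_train_batch_start; the epoch-interval scheduler still steps at epoch end.
import torch
import lightning.pytorch as pl


class EarlyEpochExit(pl.LightningModule):
    def __init__(self, max_batches_per_epoch: int = 10):
        super().__init__()
        self.layer = torch.nn.Linear(8, 1)
        self.max_batches_per_epoch = max_batches_per_epoch

    def on_train_batch_start(self, batch, batch_idx):
        # Skip the remainder of the epoch once the budget is used up.
        if batch_idx >= self.max_batches_per_epoch:
            return -1
        return None

    def training_step(self, batch, batch_idx):
        x, y = batch
        return torch.nn.functional.mse_loss(self.layer(x), y)

    def configure_optimizers(self):
        optimizer = torch.optim.SGD(self.parameters(), lr=0.1)
        # By default Lightning steps this scheduler once per epoch ("epoch" interval).
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1)
        return {"optimizer": optimizer, "lr_scheduler": scheduler}
```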

src/lightning/pytorch/loops/training_epoch_loop.py

Lines changed: 24 additions & 17 deletions

@@ -325,30 +325,33 @@ def advance(self, data_fetcher: _DataFetcher) -> None:
         trainer._logger_connector.on_batch_start(batch)

         batch_output: _BATCH_OUTPUTS_TYPE = None  # for mypy
+        should_skip_rest_of_epoch = False
+
         if batch is None and not using_dataloader_iter:
             self._warning_cache.warn("train_dataloader yielded None. If this was on purpose, ignore this warning...")
         else:
             # hook
             call._call_callback_hooks(trainer, "on_train_batch_start", batch, batch_idx)
             response = call._call_lightning_module_hook(trainer, "on_train_batch_start", batch, batch_idx)
             call._call_strategy_hook(trainer, "on_train_batch_start", batch, batch_idx)
-            if response == -1:
-                self.batch_progress.increment_processed()
-                raise StopIteration
-
-            self.batch_progress.increment_started()
-
-            kwargs = (
-                self._build_kwargs(OrderedDict(), batch, batch_idx)
-                if not using_dataloader_iter
-                else OrderedDict(any=dataloader_iter)
-            )
-            with trainer.profiler.profile("run_training_batch"):
-                if trainer.lightning_module.automatic_optimization:
-                    # in automatic optimization, there can only be one optimizer
-                    batch_output = self.automatic_optimization.run(trainer.optimizers[0], batch_idx, kwargs)
-                else:
-                    batch_output = self.manual_optimization.run(kwargs)
+            should_skip_rest_of_epoch = response == -1
+            # Signal this is the last batch for the current epoch
+            if should_skip_rest_of_epoch:
+                self.batch_progress.increment_by(0, is_last_batch=True)
+            else:
+                self.batch_progress.increment_started()
+
+                kwargs = (
+                    self._build_kwargs(OrderedDict(), batch, batch_idx)
+                    if not using_dataloader_iter
+                    else OrderedDict(any=dataloader_iter)
+                )
+                with trainer.profiler.profile("run_training_batch"):
+                    if trainer.lightning_module.automatic_optimization:
+                        # in automatic optimization, there can only be one optimizer
+                        batch_output = self.automatic_optimization.run(trainer.optimizers[0], batch_idx, kwargs)
+                    else:
+                        batch_output = self.manual_optimization.run(kwargs)

         self.batch_progress.increment_processed()

@@ -358,6 +361,10 @@ def advance(self, data_fetcher: _DataFetcher) -> None:
         if self._num_ready_batches_reached():
             self.update_lr_schedulers("epoch", update_plateau_schedulers=False)

+        if should_skip_rest_of_epoch:
+            # Only raise StopIteration now so that the training epoch loop can finish
+            raise StopIteration
+
         if using_dataloader_iter:
             # update the hook kwargs now that the step method might have consumed the iterator
             batch = data_fetcher._batch

src/lightning/pytorch/trainer/trainer.py

Lines changed: 4 additions & 2 deletions

@@ -563,14 +563,16 @@ def fit(
                 recommend using ``weights_only=True``. For more information, please refer to the
                 `PyTorch Developer Notes on Serialization Semantics <https://docs.pytorch.org/docs/main/notes/serialization.html#id3>`_.

+        For more information about multiple dataloaders, see this :ref:`section <multiple-dataloaders>`.
+
+        :rtype: :py:obj:`None`
+
         Raises:
             TypeError:
                 If ``model`` is not :class:`~lightning.pytorch.core.LightningModule` for torch version less than
                 2.0.0 and if ``model`` is not :class:`~lightning.pytorch.core.LightningModule` or
                 :class:`torch._dynamo.OptimizedModule` for torch versions greater than or equal to 2.0.0 .

-        For more information about multiple dataloaders, see this :ref:`section <multiple-dataloaders>`.
-
         """
         model = _maybe_unwrap_optimized(model)
         self.strategy._lightning_module = model
