Commit c07c9c9

Merge branch 'master' into mlflow-logging-fix
2 parents 90c6f4f + 1ec459f commit c07c9c9

21 files changed: +1174 −59 lines changed


.github/workflows/ci-tests-pytorch.yml

Lines changed: 1 addition & 1 deletion
@@ -139,7 +139,7 @@ jobs:
       pip install ".[${EXTRA_PREFIX}extra,${EXTRA_PREFIX}test,${EXTRA_PREFIX}strategies]" \
         -U --upgrade-strategy=eager --prefer-binary \
         -r requirements/_integrations/accelerators.txt \
-        --extra-index-url="${TORCH_URL}" --find-links="${PYPI_CACHE_DIR}"
+        --extra-index-url="${TORCH_URL}" --find-links="${PYPI_CACHE_DIR}" --find-links="https://download.pytorch.org/whl/torch-tensorrt"
       pip list
 - name: Drop LAI from extensions
   if: ${{ matrix.pkg-name != 'lightning' }}

docs/source-pytorch/advanced/training_tricks.rst

Lines changed: 36 additions & 11 deletions
@@ -50,23 +50,48 @@ Read more about :ref:`Configuring Gradient Clipping <configure_gradient_clipping

 ----------

-***************************
-Stochastic Weight Averaging
-***************************
+****************
+Weight Averaging
+****************

-Stochastic Weight Averaging (SWA) can make your models generalize better at virtually no additional cost.
-This can be used with both non-trained and trained models. The SWA procedure smooths the loss landscape thus making
-it harder to end up in a local minimum during optimization.
+Weight averaging methods such as Stochastic Weight Averaging (SWA) and Exponential Moving Average (EMA) can make your
+models generalize better at virtually no additional cost. Averaging smooths the loss landscape thus making it harder to
+end up in a local minimum during optimization.

-For a more detailed explanation of SWA and how it works,
-read `this post <https://pytorch.org/blog/pytorch-1.6-now-includes-stochastic-weight-averaging>`__ by the PyTorch team.
+Lightning provides two callbacks to facilitate weight averaging. :class:`~lightning.pytorch.callbacks.WeightAveraging`
+is a generic callback that wraps the
+`AveragedModel <https://pytorch.org/docs/stable/generated/torch.optim.swa_utils.AveragedModel.html>`__ class from
+PyTorch. It allows SWA, EMA, or a custom averaging strategy to be used. By default, it updates the weights after every
+step, but it can be customized to update at specific steps or epochs by overriding the `should_update()` method.

-.. seealso:: The :class:`~lightning.pytorch.callbacks.StochasticWeightAveraging` callback
+The older :class:`~lightning.pytorch.callbacks.StochasticWeightAveraging` callback is specific to SWA. It starts the SWA
+procedure after a certain number of epochs and always runs on every epoch. Additionally, it switches to a constant
+learning rate schedule (`SWALR <https://pytorch.org/docs/stable/generated/torch.optim.swa_utils.SWALR.html>`__) when the
+procedure starts.
+
+.. seealso::
+   For a more detailed explanation of SWA and how it works, read
+   `this post <https://pytorch.org/blog/pytorch-1.6-now-includes-stochastic-weight-averaging>`__ by the PyTorch team.
+
+.. seealso::
+   The :class:`~lightning.pytorch.callbacks.WeightAveraging` callback and
+   :class:`~lightning.pytorch.callbacks.StochasticWeightAveraging` callback

 .. testcode::

-    # Enable Stochastic Weight Averaging using the callback
-    trainer = Trainer(callbacks=[StochasticWeightAveraging(swa_lrs=1e-2)])
+    from lightning.pytorch.callbacks import StochasticWeightAveraging, WeightAveraging
+    from torch.optim.swa_utils import get_ema_avg_fn
+
+    # Enable Exponential Moving Average after 100 steps
+    class EMAWeightAveraging(WeightAveraging):
+        def __init__(self):
+            super().__init__(avg_fn=get_ema_avg_fn())
+        def should_update(self, step_idx=None, epoch_idx=None):
+            return (step_idx is not None) and (step_idx >= 100)
+    trainer = Trainer(callbacks=EMAWeightAveraging())
+
+    # Enable Stochastic Weight Averaging after 10 epochs with learning rate 0.01
+    trainer = Trainer(callbacks=StochasticWeightAveraging(swa_epoch_start=10, swa_lrs=0.01))

----------
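
The updated docs also say a custom averaging strategy can be plugged into `WeightAveraging`. Below is a minimal sketch of what that could look like, assuming (as the EMA example above suggests) that the callback forwards `avg_fn` to `torch.optim.swa_utils.AveragedModel` and that `should_update()` receives an `epoch_idx` at epoch boundaries; the averaging coefficient and class name are illustrative, not part of this commit.

```python
from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import WeightAveraging


def custom_avg_fn(averaged_param, current_param, num_averaged):
    # Keep 99% of the running average and mix in 1% of the current weights;
    # this is the positional signature AveragedModel uses for a custom avg_fn.
    return 0.99 * averaged_param + 0.01 * current_param


class EpochwiseAveraging(WeightAveraging):
    def __init__(self):
        # avg_fn is forwarded to torch.optim.swa_utils.AveragedModel,
        # mirroring the get_ema_avg_fn() example in the hunk above.
        super().__init__(avg_fn=custom_avg_fn)

    def should_update(self, step_idx=None, epoch_idx=None):
        # Update the averaged weights once per epoch instead of after every step.
        return epoch_idx is not None


trainer = Trainer(callbacks=EpochwiseAveraging())
```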

docs/source-pytorch/api_references.rst

Lines changed: 1 addition & 0 deletions
@@ -48,6 +48,7 @@ callbacks
     ThroughputMonitor
     Timer
     TQDMProgressBar
+    WeightAveraging

 cli
 -----

docs/source-pytorch/extensions/callbacks.rst

Lines changed: 1 addition & 0 deletions
@@ -83,6 +83,7 @@ Lightning has a few built-in callbacks.
     StochasticWeightAveraging
     Timer
     TQDMProgressBar
+    WeightAveraging

 ----------

docs/source-pytorch/glossary/index.rst

Lines changed: 8 additions & 8 deletions
@@ -42,13 +42,13 @@
     Strategy registry <../advanced/strategy_registry>
     Strategy integrations <../integrations/strategies/index>
     Style guide <../starter/style_guide>
-    SWA <../advanced/training_tricks>
     SLURM <../clouds/cluster_advanced>
     Tensor Parallel <../advanced/model_parallel/tp>
     Transfer learning <../advanced/transfer_learning>
     Trainer <../common/trainer>
     TorchRun (TorchElastic) <../clouds/cluster_intermediate_2>
     Warnings <../advanced/warnings>
+    Weight averaging <../advanced/training_tricks>


 ########
@@ -326,13 +326,6 @@
    :button_link: ../starter/style_guide.html
    :height: 100

-.. displayitem::
-   :header: SWA
-   :description: Stochastic Weight Averaging (SWA) can make your models generalize better
-   :col_css: col-md-12
-   :button_link: ../advanced/training_tricks.html#stochastic-weight-averaging
-   :height: 100
-
 .. displayitem::
    :header: SLURM
    :description: Simple Linux Utility for Resource Management, or simply Slurm, is a free and open-source job scheduler for Linux clusters
@@ -375,6 +368,13 @@ Glossary
    :button_link: ../advanced/warnings.html
    :height: 100

+.. displayitem::
+   :header: Weight averaging
+   :description: Stochastic Weight Averaging (SWA) or Exponential Moving Average (EMA) can make your models generalize better
+   :col_css: col-md-12
+   :button_link: ../advanced/training_tricks.html#weight-averaging
+   :height: 100
+
 .. raw:: html

     </div>

docs/source-pytorch/model/build_model_intermediate.rst

Lines changed: 1 addition & 1 deletion
@@ -27,7 +27,7 @@ Enable advanced training features using Trainer arguments. These are SOTA techni
     )

     # access the latest state of the art techniques
-    trainer = Trainer(callbacks=[StochasticWeightAveraging(...)])
+    trainer = Trainer(callbacks=[WeightAveraging(...)])

 ----

docs/source-pytorch/starter/introduction.rst

Lines changed: 1 addition & 1 deletion
@@ -252,7 +252,7 @@ Enable advanced training features using Trainer arguments. These are state-of-th
     )

     # access the latest state of the art techniques
-    trainer = L.Trainer(callbacks=[StochasticWeightAveraging(...)])
+    trainer = L.Trainer(callbacks=[WeightAveraging(...)])

 ----

requirements/pytorch/test.txt

Lines changed: 3 additions & 0 deletions
@@ -19,3 +19,6 @@ uvicorn # for `ServableModuleValidator` # not setting version as re-defined in

 tensorboard >=2.9.1, <2.21.0 # for `TensorBoardLogger`
 mlflow >=3.0.0, <4.0 # for `MLFlowLogger
+
+--find-links https://download.pytorch.org/whl/torch-tensorrt
+torch-tensorrt; platform_system == "Linux" and python_version >= "3.12"

src/lightning/fabric/CHANGELOG.md

Lines changed: 14 additions & 4 deletions
@@ -19,18 +19,28 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

 ### Changed

-- Raise ValueError when seed is `out-of-bounds` or `cannot be cast to int` ([#21029](https://github.com/Lightning-AI/pytorch-lightning/pull/21029))
+-


 ### Fixed

-- Fix XLA strategy to add support for `global_ordinal`, `local_ordinal`, `world_size` which came instead of deprecated methods ([#20852](https://github.com/Lightning-AI/pytorch-lightning/issues/20852))
+-


-- fix: remove extra `name` parameter in accelerator registry decorator ([#20975](https://github.com/Lightning-AI/pytorch-lightning/pull/20975))
+---

+## [2.5.3] - 2025-08-13
+
+### Changed
+
+- Enable "auto" for `devices` and `accelerator` as CLI arguments ([#20913](https://github.com/Lightning-AI/pytorch-lightning/pull/20913))
+- Raise ValueError when seed is `out-of-bounds` or `cannot be cast to int` ([#21029](https://github.com/Lightning-AI/pytorch-lightning/pull/21029))
+
+### Fixed
+
+- Fixed XLA strategy to add support for `global_ordinal`, `local_ordinal`, `world_size` which came instead of deprecated methods ([#20852](https://github.com/Lightning-AI/pytorch-lightning/issues/20852))
+- Fixed remove extra `name` parameter in accelerator registry decorator ([#20975](https://github.com/Lightning-AI/pytorch-lightning/pull/20975))

----

 ## [2.5.2] - 2025-3-20

src/lightning/fabric/utilities/distributed.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -319,7 +319,11 @@ def _destroy_dist_connection() -> None:
319319

320320

321321
def _get_default_process_group_backend_for_device(device: torch.device) -> str:
322-
return "nccl" if device.type == "cuda" else "gloo"
322+
"""Return corresponding distributed backend for a given device."""
323+
device_backend_map = torch.distributed.Backend.default_device_backend_map
324+
if device.type in device_backend_map:
325+
return device_backend_map[device.type]
326+
return "gloo"
323327

324328

325329
class _DatasetSamplerWrapper(Dataset):
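
For context, the new lookup in `_get_default_process_group_backend_for_device` can be exercised on its own. The sketch below copies the function from the hunk above and adds a tiny check; it assumes a PyTorch build that exposes `torch.distributed.Backend.default_device_backend_map`, and the map contents (for example "cpu" mapping to "gloo" and "cuda" to "nccl") vary with how PyTorch was built.

```python
import torch
import torch.distributed


def _get_default_process_group_backend_for_device(device: torch.device) -> str:
    """Return corresponding distributed backend for a given device."""
    device_backend_map = torch.distributed.Backend.default_device_backend_map
    if device.type in device_backend_map:
        return device_backend_map[device.type]
    return "gloo"


if __name__ == "__main__":
    # Inspect the default device-to-backend mapping shipped with this PyTorch build.
    print(torch.distributed.Backend.default_device_backend_map)
    print(_get_default_process_group_backend_for_device(torch.device("cpu")))
    # A device type missing from the map (e.g. "meta") falls back to "gloo".
    print(_get_default_process_group_backend_for_device(torch.device("meta")))
```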
