Merged
Changes from 24 commits
Commits
25 commits
c8d50bd
Weight averaging callback
Borda Jan 9, 2025
075bfcf
Merge branch 'master' into generic-weight-averaging
lantiga Feb 3, 2025
99b6638
More generic customization of the WeightAveraging callback
Feb 4, 2025
13f5298
Merge branch 'master' into generic-weight-averaging
Mar 22, 2025
aec9f6e
Training tricks mentions WeightAveraging and EMA
Mar 22, 2025
247935f
Removed logging from WeightAveraging
Apr 3, 2025
7920118
Merge branch 'master' into generic-weight-averaging
Apr 3, 2025
822231f
Fixed the documentation
Apr 3, 2025
5deb0bb
Fixed checkpoint loading with WeightAveraging
Apr 3, 2025
5a69057
WeightAveraging calls the configure_model hook but issues a warning
Apr 26, 2025
f3529f4
Merge branch 'master' into generic-weight-averaging
Apr 26, 2025
3dafb4c
Fixed unit tests
Apr 26, 2025
3fd3c22
Merge branch 'master' into generic-weight-averaging
Borda Jun 27, 2025
410fe14
The default device for the averaged model is the device of the origin…
Jul 6, 2025
3aec635
Merge branch 'master' into generic-weight-averaging
Borda Aug 8, 2025
4f38156
Added seealso to WeightAveraging and StochasticWeightAveraging
senarvi Aug 11, 2025
01161a9
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 11, 2025
b5e8877
More verbose description of WeightAveraging
Aug 11, 2025
908355e
Describe the magic number 7 in a comment
Aug 11, 2025
a290638
Merge branch 'master' into generic-weight-averaging
Borda Aug 12, 2025
0adb37a
Update src/lightning/pytorch/CHANGELOG.md
SkafteNicki Aug 13, 2025
c5678b5
Merge branch 'master' into generic-weight-averaging
SkafteNicki Aug 13, 2025
6012908
Merge branch 'master' into generic-weight-averaging
Borda Aug 15, 2025
54c84c7
Merge branch 'master' into generic-weight-averaging
SkafteNicki Aug 15, 2025
43c2023
Merge branch 'master' into generic-weight-averaging
Borda Aug 15, 2025
47 changes: 36 additions & 11 deletions docs/source-pytorch/advanced/training_tricks.rst
@@ -50,23 +50,48 @@ Read more about :ref:`Configuring Gradient Clipping <configure_gradient_clipping

----------

***************************
Stochastic Weight Averaging
***************************
****************
Weight Averaging
****************

Stochastic Weight Averaging (SWA) can make your models generalize better at virtually no additional cost.
This can be used with both non-trained and trained models. The SWA procedure smooths the loss landscape thus making
it harder to end up in a local minimum during optimization.
Weight averaging methods such as Stochastic Weight Averaging (SWA) and Exponential Moving Average (EMA) can make your
models generalize better at virtually no additional cost. Averaging smooths the loss landscape, making it harder to end
up in a sharp local minimum during optimization.

For a more detailed explanation of SWA and how it works,
read `this post <https://pytorch.org/blog/pytorch-1.6-now-includes-stochastic-weight-averaging>`__ by the PyTorch team.
Lightning provides two callbacks to facilitate weight averaging. :class:`~lightning.pytorch.callbacks.WeightAveraging`
is a generic callback that wraps the
`AveragedModel <https://pytorch.org/docs/stable/generated/torch.optim.swa_utils.AveragedModel.html>`__ class from
PyTorch. It allows SWA, EMA, or a custom averaging strategy to be used. By default, it updates the weights after every
step, but it can be customized to update at specific steps or epochs by overriding the ``should_update()`` method.

.. seealso:: The :class:`~lightning.pytorch.callbacks.StochasticWeightAveraging` callback
The older :class:`~lightning.pytorch.callbacks.StochasticWeightAveraging` callback is specific to SWA. It starts the SWA
procedure after a given number of epochs and always updates the averaged weights once per epoch. Additionally, it switches
to a constant learning rate schedule (`SWALR <https://pytorch.org/docs/stable/generated/torch.optim.swa_utils.SWALR.html>`__)
when the procedure starts.

.. seealso::
For a more detailed explanation of SWA and how it works, read
`this post <https://pytorch.org/blog/pytorch-1.6-now-includes-stochastic-weight-averaging>`__ by the PyTorch team.

.. seealso::
The :class:`~lightning.pytorch.callbacks.WeightAveraging` callback and
:class:`~lightning.pytorch.callbacks.StochasticWeightAveraging` callback

.. testcode::

# Enable Stochastic Weight Averaging using the callback
trainer = Trainer(callbacks=[StochasticWeightAveraging(swa_lrs=1e-2)])
from lightning.pytorch.callbacks import StochasticWeightAveraging, WeightAveraging
from torch.optim.swa_utils import get_ema_avg_fn

# Enable Exponential Moving Average after 100 steps
class EMAWeightAveraging(WeightAveraging):
    def __init__(self):
        super().__init__(avg_fn=get_ema_avg_fn())

    def should_update(self, step_idx=None, epoch_idx=None):
        return (step_idx is not None) and (step_idx >= 100)


trainer = Trainer(callbacks=EMAWeightAveraging())

# Enable Stochastic Weight Averaging after 10 epochs with learning rate 0.01
trainer = Trainer(callbacks=StochasticWeightAveraging(swa_epoch_start=10, swa_lrs=0.01))

----------

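A minimal sketch of the epoch-based customization described in the paragraph above, assuming that the `avg_fn` callable
follows the `(averaged_param, current_param, num_averaged)` signature documented for
`torch.optim.swa_utils.AveragedModel`; the class name `EpochSWA`, the epoch-10 threshold, and `max_epochs` are
illustrative and not part of the diff above.

from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import WeightAveraging


class EpochSWA(WeightAveraging):
    """Running-mean (SWA-style) averaging, updated once per epoch from epoch 10 onward."""

    def __init__(self):
        # Incremental running mean of the parameters seen so far, using AveragedModel's
        # (averaged_param, current_param, num_averaged) avg_fn signature.
        def running_mean(averaged_param, current_param, num_averaged):
            return averaged_param + (current_param - averaged_param) / (num_averaged + 1)

        super().__init__(avg_fn=running_mean)

    def should_update(self, step_idx=None, epoch_idx=None):
        # Update the averaged weights at epoch boundaries only, starting from epoch 10.
        return (epoch_idx is not None) and (epoch_idx >= 10)


trainer = Trainer(callbacks=EpochSWA(), max_epochs=50)
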
1 change: 1 addition & 0 deletions docs/source-pytorch/api_references.rst
@@ -48,6 +48,7 @@ callbacks
ThroughputMonitor
Timer
TQDMProgressBar
WeightAveraging

cli
-----
1 change: 1 addition & 0 deletions docs/source-pytorch/extensions/callbacks.rst
@@ -83,6 +83,7 @@ Lightning has a few built-in callbacks.
StochasticWeightAveraging
Timer
TQDMProgressBar
WeightAveraging

----------

16 changes: 8 additions & 8 deletions docs/source-pytorch/glossary/index.rst
@@ -42,13 +42,13 @@
Strategy registry <../advanced/strategy_registry>
Strategy integrations <../integrations/strategies/index>
Style guide <../starter/style_guide>
SWA <../advanced/training_tricks>
SLURM <../clouds/cluster_advanced>
Tensor Parallel <../advanced/model_parallel/tp>
Transfer learning <../advanced/transfer_learning>
Trainer <../common/trainer>
TorchRun (TorchElastic) <../clouds/cluster_intermediate_2>
Warnings <../advanced/warnings>
Weight averaging <../advanced/training_tricks>


########
@@ -326,13 +326,6 @@ Glossary
:button_link: ../starter/style_guide.html
:height: 100

.. displayitem::
:header: SWA
:description: Stochastic Weight Averaging (SWA) can make your models generalize better
:col_css: col-md-12
:button_link: ../advanced/training_tricks.html#stochastic-weight-averaging
:height: 100

.. displayitem::
:header: SLURM
:description: Simple Linux Utility for Resource Management, or simply Slurm, is a free and open-source job scheduler for Linux clusters
@@ -375,6 +368,13 @@
:button_link: ../advanced/warnings.html
:height: 100

.. displayitem::
:header: Weight averaging
:description: Stochastic Weight Averaging (SWA) or Exponential Moving Average (EMA) can make your models generalize better
:col_css: col-md-12
:button_link: ../advanced/training_tricks.html#weight-averaging
:height: 100

.. raw:: html

</div>
2 changes: 1 addition & 1 deletion docs/source-pytorch/model/build_model_intermediate.rst
@@ -27,7 +27,7 @@ Enable advanced training features using Trainer arguments. These are SOTA techni
)

# access the latest state of the art techniques
trainer = Trainer(callbacks=[StochasticWeightAveraging(...)])
trainer = Trainer(callbacks=[WeightAveraging(...)])

----

2 changes: 1 addition & 1 deletion docs/source-pytorch/starter/introduction.rst
@@ -252,7 +252,7 @@ Enable advanced training features using Trainer arguments. These are state-of-th
)
# access the latest state of the art techniques
trainer = L.Trainer(callbacks=[StochasticWeightAveraging(...)])
trainer = L.Trainer(callbacks=[WeightAveraging(...)])
----

2 changes: 1 addition & 1 deletion src/lightning/pytorch/CHANGELOG.md
@@ -10,7 +10,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Added

-
- WeightAveraging callback that wraps the PyTorch AveragedModel class ([#20545](https://github.com/Lightning-AI/pytorch-lightning/pull/20545))


### Changed
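For context on what this changelog entry wraps, a minimal sketch in plain PyTorch of how `AveragedModel` accumulates an
average of the weights; the toy model, optimizer, and training loop are illustrative assumptions, and the
`WeightAveraging` callback is what automates the `update_parameters()` calls during a Lightning fit.

import torch
from torch.optim.swa_utils import AveragedModel, get_ema_avg_fn

# AveragedModel keeps a copy of the model and folds the current weights into an
# exponential moving average on every update_parameters() call.
model = torch.nn.Linear(10, 2)
ema_model = AveragedModel(model, avg_fn=get_ema_avg_fn(decay=0.999))
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

for _ in range(100):
    optimizer.zero_grad()
    loss = model(torch.randn(8, 10)).pow(2).mean()
    loss.backward()
    optimizer.step()
    ema_model.update_parameters(model)  # the step the callback presumably automates
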
2 changes: 2 additions & 0 deletions src/lightning/pytorch/callbacks/__init__.py
@@ -32,6 +32,7 @@
from lightning.pytorch.callbacks.stochastic_weight_avg import StochasticWeightAveraging
from lightning.pytorch.callbacks.throughput_monitor import ThroughputMonitor
from lightning.pytorch.callbacks.timer import Timer
from lightning.pytorch.callbacks.weight_averaging import WeightAveraging

__all__ = [
"BackboneFinetuning",
@@ -58,4 +59,5 @@
"ThroughputMonitor",
"Timer",
"TQDMProgressBar",
"WeightAveraging",
]
2 changes: 1 addition & 1 deletion src/lightning/pytorch/callbacks/stochastic_weight_avg.py
@@ -65,7 +65,7 @@ def __init__(

.. warning:: ``StochasticWeightAveraging`` is currently only supported on every epoch.

See also how to :ref:`enable it directly on the Trainer <advanced/training_tricks:Stochastic Weight Averaging>`
See also how to :ref:`enable it directly on the Trainer <advanced/training_tricks:Weight Averaging>`.

Arguments:
