
Commit aec9f6e

Author: Seppo Enarvi
Training tricks mentions WeightAveraging and EMA
Parent: 13f5298

File tree: 3 files changed (+34, -13 lines)

docs/source-pytorch/advanced/training_tricks.rst
docs/source-pytorch/model/build_model_intermediate.rst
docs/source-pytorch/starter/introduction.rst

docs/source-pytorch/advanced/training_tricks.rst

Lines changed: 32 additions & 11 deletions
@@ -50,23 +50,44 @@ Read more about :ref:`Configuring Gradient Clipping <configure_gradient_clipping
 
 ----------
 
-***************************
-Stochastic Weight Averaging
-***************************
+****************
+Weight Averaging
+****************
 
-Stochastic Weight Averaging (SWA) can make your models generalize better at virtually no additional cost.
-This can be used with both non-trained and trained models. The SWA procedure smooths the loss landscape thus making
-it harder to end up in a local minimum during optimization.
+Weight averaging methods such as Stochastic Weight Averaging (SWA) and Exponential Moving Average (EMA) can make your
+models generalize better at virtually no additional cost. Averaging smooths the loss landscape thus making it harder to
+end up in a local minimum during optimization.
 
-For a more detailed explanation of SWA and how it works,
-read `this post <https://pytorch.org/blog/pytorch-1.6-now-includes-stochastic-weight-averaging>`__ by the PyTorch team.
+Lightning provides two callbacks to facilitate weight averaging. :class:`~lightning.pytorch.callbacks.WeightAveraging`
+is a generic callback that wraps the
+`AveragedModel <https://pytorch.org/docs/stable/generated/torch.optim.swa_utils.AveragedModel.html>`__ class from
+PyTorch. It allows SWA, EMA, or a custom averaging strategy to be used and it can be customized to run at specific steps
+or epochs.
 
-.. seealso:: The :class:`~lightning.pytorch.callbacks.StochasticWeightAveraging` callback
+The older :class:`~lightning.pytorch.callbacks.StochasticWeightAveraging` callback is specific to SWA. It starts the SWA
+procedure after a certain number of epochs and always runs on every epoch. Additionally, it switches to a constant
+learning rate schedule (`SWALR <https://pytorch.org/docs/stable/generated/torch.optim.swa_utils.SWALR.html>`__) when the
+procedure starts.
+
+.. seealso::
+    For a more detailed explanation of SWA and how it works, read
+    `this post <https://pytorch.org/blog/pytorch-1.6-now-includes-stochastic-weight-averaging>`__ by the PyTorch team.
 
 .. testcode::
 
-    # Enable Stochastic Weight Averaging using the callback
-    trainer = Trainer(callbacks=[StochasticWeightAveraging(swa_lrs=1e-2)])
+    from lightning.pytorch.callbacks import StochasticWeightAveraging, WeightAveraging
+    from torch.optim.swa_utils import get_ema_avg_fn
+
+    # Enable Exponential Moving Average after 100 steps
+    class EMAWeightAveraging(WeightAveraging):
+        def __init__(self):
+            super().__init__(avg_fn=get_ema_avg_fn())
+        def should_update(self, step_idx=None, epoch_idx=None):
+            return (step_idx is not None) and (step_idx >= 100)
+    trainer = Trainer(callbacks=EMAWeightAveraging())
+
+    # Enable Stochastic Weight Averaging after 10 epochs with learning rate 0.01
+    trainer = Trainer(callbacks=StochasticWeightAveraging(swa_epoch_start=10, swa_lrs=0.01))
 
 ----------
 
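The docs above note that the new callback can be driven by epochs as well as steps. A minimal sketch of an epoch-driven variant, not part of this commit: the class name SWAFromEpochTen is hypothetical, and the sketch assumes the should_update hook shown in the diff and AveragedModel's default behavior of keeping an equally weighted average when no avg_fn is given.

    from lightning.pytorch import Trainer
    from lightning.pytorch.callbacks import WeightAveraging

    # Hypothetical sketch: SWA-style averaging updated at epoch boundaries.
    # With no avg_fn passed, the wrapped AveragedModel keeps an equally
    # weighted running mean of the parameters (the SWA update rule).
    class SWAFromEpochTen(WeightAveraging):
        def should_update(self, step_idx=None, epoch_idx=None):
            # epoch_idx is provided at epoch boundaries; start averaging at epoch 10
            return (epoch_idx is not None) and (epoch_idx >= 10)

    trainer = Trainer(callbacks=SWAFromEpochTen())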

docs/source-pytorch/model/build_model_intermediate.rst

Lines changed: 1 addition & 1 deletion
@@ -27,7 +27,7 @@ Enable advanced training features using Trainer arguments. These are SOTA techni
 )
 
 # access the latest state of the art techniques
-trainer = Trainer(callbacks=[StochasticWeightAveraging(...)])
+trainer = Trainer(callbacks=[WeightAveraging(...)])
 
 ----

docs/source-pytorch/starter/introduction.rst

Lines changed: 1 addition & 1 deletion
@@ -252,7 +252,7 @@ Enable advanced training features using Trainer arguments. These are state-of-th
 )
 
 # access the latest state of the art techniques
-trainer = L.Trainer(callbacks=[StochasticWeightAveraging(...)])
+trainer = L.Trainer(callbacks=[WeightAveraging(...)])
 
 ----
 
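Both call sites keep the constructor arguments elided with "...". As an illustration only, not from this commit, one plausible instantiation passes an EMA update function from torch.optim.swa_utils; this assumes WeightAveraging forwards keyword arguments to AveragedModel as described in the training_tricks diff above, and decay=0.999 is simply get_ema_avg_fn's default made explicit.

    import lightning as L
    from lightning.pytorch.callbacks import WeightAveraging
    from torch.optim.swa_utils import get_ema_avg_fn

    # Hypothetical way to fill in the elided arguments: EMA averaging,
    # with the decay made explicit (0.999 is get_ema_avg_fn's default)
    trainer = L.Trainer(callbacks=[WeightAveraging(avg_fn=get_ema_avg_fn(decay=0.999))])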
