@@ -50,23 +50,44 @@ Read more about :ref:`Configuring Gradient Clipping <configure_gradient_clipping
 
 ----------
 
-***************************
-Stochastic Weight Averaging
-***************************
+****************
+Weight Averaging
+****************
 
-Stochastic Weight Averaging (SWA) can make your models generalize better at virtually no additional cost.
-This can be used with both non-trained and trained models. The SWA procedure smooths the loss landscape thus making
-it harder to end up in a local minimum during optimization.
+Weight averaging methods such as Stochastic Weight Averaging (SWA) and Exponential Moving Average (EMA) can make your
+models generalize better at virtually no additional cost. Averaging smooths the loss landscape, thus making it harder
+to end up in a local minimum during optimization.
 
-For a more detailed explanation of SWA and how it works,
-read `this post <https://pytorch.org/blog/pytorch-1.6-now-includes-stochastic-weight-averaging>`__ by the PyTorch team.
+Lightning provides two callbacks to facilitate weight averaging. :class:`~lightning.pytorch.callbacks.WeightAveraging`
+is a generic callback that wraps the
+`AveragedModel <https://pytorch.org/docs/stable/generated/torch.optim.swa_utils.AveragedModel.html>`__ class from
+PyTorch. It allows SWA, EMA, or a custom averaging strategy to be used, and it can be customized to run at specific
+steps or epochs.
 
-.. seealso:: The :class:`~lightning.pytorch.callbacks.StochasticWeightAveraging` callback
+The older :class:`~lightning.pytorch.callbacks.StochasticWeightAveraging` callback is specific to SWA. It starts the SWA
+procedure after a certain number of epochs and always runs on every epoch. Additionally, it switches to a constant
+learning rate schedule (`SWALR <https://pytorch.org/docs/stable/generated/torch.optim.swa_utils.SWALR.html>`__) when the
+procedure starts.
+
+.. seealso::
+    For a more detailed explanation of SWA and how it works, read
+    `this post <https://pytorch.org/blog/pytorch-1.6-now-includes-stochastic-weight-averaging>`__ by the PyTorch team.
 
 .. testcode::
 
-    # Enable Stochastic Weight Averaging using the callback
-    trainer = Trainer(callbacks=[StochasticWeightAveraging(swa_lrs=1e-2)])
+    from lightning.pytorch.callbacks import StochasticWeightAveraging, WeightAveraging
+    from torch.optim.swa_utils import get_ema_avg_fn
+
+
+    # Enable Exponential Moving Average after 100 steps
+    class EMAWeightAveraging(WeightAveraging):
+        def __init__(self):
+            super().__init__(avg_fn=get_ema_avg_fn())
+
+        def should_update(self, step_idx=None, epoch_idx=None):
+            # Start updating the averaged weights after 100 training steps
+            return (step_idx is not None) and (step_idx >= 100)
+
+
+    trainer = Trainer(callbacks=EMAWeightAveraging())
+
+    # Enable Stochastic Weight Averaging after 10 epochs with learning rate 0.01
+    trainer = Trainer(callbacks=StochasticWeightAveraging(swa_epoch_start=10, swa_lrs=0.01))
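+
+If the built-in SWA and EMA strategies are not enough, a custom averaging function can be plugged in as well. The sketch
+below is only illustrative: it assumes that :class:`~lightning.pytorch.callbacks.WeightAveraging` forwards extra keyword
+arguments such as ``avg_fn`` to ``AveragedModel`` (as the EMA example above does), and the ``warmup_ema_avg_fn`` name is
+made up for this example. The function follows the ``AveragedModel`` convention of receiving the current averaged
+parameter, the new model parameter, and the number of updates averaged so far.
+
+.. code-block:: python
+
+    from lightning.pytorch import Trainer
+    from lightning.pytorch.callbacks import WeightAveraging
+
+
+    def warmup_ema_avg_fn(averaged_param, current_param, num_averaged):
+        # Exponential moving average whose decay ramps up as more updates are averaged
+        decay = min(0.999, (1.0 + float(num_averaged)) / (10.0 + float(num_averaged)))
+        return decay * averaged_param + (1.0 - decay) * current_param
+
+
+    # Average the weights with the custom rule above
+    trainer = Trainer(callbacks=WeightAveraging(avg_fn=warmup_ema_avg_fn))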
 
 ----------
 