 
 import itertools
 from copy import deepcopy
-from typing import Any, Callable, Optional, Union
+from typing import Any, Optional, Union
 
 import torch
-from torch import Tensor
 from torch.optim.swa_utils import AveragedModel
 from typing_extensions import override
 
@@ -35,32 +34,63 @@ class WeightAveraging(Callback):
     r"""A callback that updates an averaged model for Stochastic Weight Averaging (SWA) or Exponential Moving Average
     (EMA) after each training step.
 
-    The user can customize when the average model is updated by overriding the ``should_update()`` method.
+    Arguments given to the constructor will be passed to the :class:`AveragedModel` constructor. There are, however,
+    a couple of differences from the default values. By default, the average model is stored on the CPU. If
+    ``device`` is set to ``None``, the device will be inferred from the original model. By default, the callback will
+    compute running averages for both the parameters and the buffers of the model. Setting ``use_buffers`` to
+    ``False`` will cause only the model parameters to be averaged, leaving the update of the batch normalization
+    statistics to the user (using ``torch.optim.swa_utils.update_bn()``).
+
+    You can provide a custom averaging function with the ``avg_fn`` or ``multi_avg_fn`` parameter. See the
+    :class:`AveragedModel` class for details. If no averaging function is provided, the default is to compute an
+    equally weighted average of the weights (SWA).
+
+    You can customize when the average model is updated by overriding the ``should_update()`` method. The callback
+    calls it with either ``step_idx`` or ``epoch_idx``, and the method returns a boolean indicating whether to update
+    after the given step or epoch. The default is to update after every step.
 
     During validation and after the training finishes, the current model parameters will be replaced with the averaged
     values.
 
+    Example::
+
+        from lightning.pytorch.callbacks import WeightAveraging
+        from torch.optim.swa_utils import get_ema_avg_fn
+
+        class EMAWeightAveraging(WeightAveraging):
+            def __init__(self):
+                super().__init__(avg_fn=get_ema_avg_fn())
+
+            def should_update(self, step_idx=None, epoch_idx=None):
+                # Start after 100 steps.
+                return (step_idx is not None) and (step_idx >= 100)
+
+        trainer = Trainer(callbacks=EMAWeightAveraging(), max_epochs=10)
+        trainer.fit(model, dataloader)
+
     Args:
         device: If provided, the :class:`AveragedModel` will be stored on the ``device``. If ``None``, the device will
             be inferred from the original model.
-        avg_fn: The averaging function used to update the parameters. The function must take in an
-            :class:`AveragedModel` parameter, a current model parameter, and the number of models already averaged. If
-            ``None``, an equally weighted average will be used.
+        use_buffers: If ``False``, the buffers of the model will not be averaged.
+        kwargs: Additional keyword arguments to be passed to the :class:`AveragedModel` constructor, such as
+            ``avg_fn`` or ``multi_avg_fn``.
 
     """
 
     def __init__(
         self,
         device: Optional[Union[torch.device, str, int]] = "cpu",
-        avg_fn: Optional[Callable[[Tensor, Tensor, Union[Tensor, int]], Tensor]] = None,
-    ):
+        use_buffers: bool = True,
+        **kwargs: Any,
+    ) -> None:
         # The default value is a string so that jsonargparse knows how to serialize it.
         if isinstance(device, str):
             self._device: Optional[Union[torch.device, int]] = torch.device(device)
         else:
             self._device = device
+        self._use_buffers = use_buffers
+        self._kwargs = kwargs
 
-        self._avg_fn = avg_fn
         self._average_model: Optional[AveragedModel] = None
 
         # Number of optimizer steps taken when the average model was last updated. Initializing this with zero ensures
@@ -76,8 +106,9 @@ def should_update(self, step_idx: Optional[int] = None, epoch_idx: Optional[int]
         """Called after every optimizer step and after every training epoch to check whether the average model should
         be updated.
 
-        One of the arguments is set to the zero-based index of the last training step or epoch. The user can customize
-        when the average model gets updated by overriding this method.
+        One of the arguments is set to the zero-based index of the last training step or epoch. The default
+        implementation returns ``True`` whenever ``step_idx`` is provided. The user can customize when the average
+        model gets updated by overriding this method.
 
         Args:
             step_idx: Index of the last optimizer step, or ``None`` when called at the epoch end.
@@ -103,7 +134,9 @@ def setup(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", stage: s
         """
         if stage == "fit":
             device = self._device or pl_module.device
-            self._average_model = AveragedModel(model=pl_module, device=device, avg_fn=self._avg_fn, use_buffers=True)
+            self._average_model = AveragedModel(
+                model=pl_module, device=device, use_buffers=self._use_buffers, **self._kwargs
+            )
 
     @override
     def on_train_batch_end(
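
When ``use_buffers`` is set to ``False``, the docstring above leaves updating the batch normalization statistics to
the user. A minimal sketch of that workflow, assuming a hypothetical ``model`` and ``train_loader``, and using
``torch.optim.swa_utils.update_bn()`` as the docstring suggests::

    from lightning.pytorch import Trainer
    from lightning.pytorch.callbacks import WeightAveraging
    from torch.optim.swa_utils import update_bn

    # Average only the model parameters; buffers such as the BatchNorm running
    # statistics are left to the user.
    trainer = Trainer(callbacks=WeightAveraging(use_buffers=False), max_epochs=10)
    trainer.fit(model, train_loader)

    # After training finishes, the callback has replaced the model parameters with
    # the averaged values, so one pass over the data recomputes the BatchNorm
    # statistics for the averaged weights.
    update_bn(train_loader, model)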
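The removed docstring lines also spelled out the ``avg_fn`` contract, which the new text defers to
:class:`AveragedModel`: the function receives an averaged parameter, the current model parameter, and the number of
models already averaged. For reference, a sketch of a custom averaging function under that contract (the helper name
is illustrative), reproducing the equally weighted SWA default::

    from lightning.pytorch.callbacks import WeightAveraging

    def equal_weight_avg_fn(averaged_param, current_param, num_averaged):
        # Incremental form of the equally weighted average over all updates so far.
        return averaged_param + (current_param - averaged_param) / (num_averaged + 1)

    # Forwarded through **kwargs to the AveragedModel constructor.
    callback = WeightAveraging(avg_fn=equal_weight_avg_fn)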