
Commit b46188b

Seppo Enarvi committed
Any keyword arguments will be passed to the AveragedModel constructor
1 parent 42d91cd commit b46188b

1 file changed: 45 additions, 12 deletions

src/lightning/pytorch/callbacks/weight_averaging.py

@@ -18,10 +18,9 @@
 
 import itertools
 from copy import deepcopy
-from typing import Any, Callable, Optional, Union
+from typing import Any, Optional, Union
 
 import torch
-from torch import Tensor
 from torch.optim.swa_utils import AveragedModel
 from typing_extensions import override
 
@@ -35,32 +34,63 @@ class WeightAveraging(Callback):
     r"""A callback that updates an averaged model for Stochastic Weight Averaging (SWA) or Exponential Moving Average
     (EMA) after each training step.
 
-    The user can customize when the average model is updated by overriding the ``should_update()`` method.
+    Arguments given to the constructor will be passed to the :class:`AveragedModel` constructor. There are a couple of
+    differences to the default values, however. By default, the average model is stored on the CPU. If ``device`` is set
+    to ``None``, the device will be inferred from the original model. By default, the callback will compute running
+    averages for both the parameters and the buffers of the model. Setting ``use_buffers`` to ``False`` will cause only
+    the model parameters to be averaged, leaving updating the batch normalization statistics to the user (using
+    ``torch.optim.swa_utils.update_bn()``).
+
+    You can provide a custom averaging function with the ``avg_fn`` or ``multi_avg_fn`` parameter. See the
+    :class:`AveragedModel` class for details. If no averaging function is provided, the default is to compute the
+    equally-weighted average of the weights (SWA).
+
+    You can customize when the average model is updated by overriding the ``should_update()`` method. The callback calls
+    it with either ``step_idx`` or ``epoch_idx`` and the method returns a boolean indicating whether to update after the
+    given step or epoch. The default is to update after every step.
 
     During validation and after the training finishes, the current model parameters will be replaced with the averaged
     values.
 
+    Example::
+
+        from lightning.pytorch.callbacks import WeightAveraging
+        from torch.optim.swa_utils import get_ema_avg_fn
+
+        class EMAWeightAveraging(WeightAveraging):
+            def __init__(self):
+                super().__init__(avg_fn=get_ema_avg_fn())
+
+            def should_update(self, step_idx=None, epoch_idx=None):
+                # Start after 100 steps.
+                return (step_idx is not None) and (step_idx >= 100)
+
+        trainer = Trainer(callbacks=EMAWeightAveraging(), max_epochs=10)
+        trainer.fit(model, dataloader)
+
     Args:
         device: If provided, the :class:`AveragedModel` will be stored on the ``device``. If ``None`` the device will be
            inferred from the original model.
-        avg_fn: The averaging function used to update the parameters. The function must take in an
-            :class:`AveragedModel` parameter, a current model parameter, and the number of models already averaged. If
-            ``None``, an equally weighted average will be used.
+        use_buffers: If ``False``, the buffers of the model will not be averaged.
+        kwargs: Additional keyword arguments to be passed to the :class:`AveragedModel` constructor, such as ``avg_fn``
+            or ``multi_avg_fn``.
 
     """
 
     def __init__(
         self,
         device: Optional[Union[torch.device, str, int]] = "cpu",
-        avg_fn: Optional[Callable[[Tensor, Tensor, Union[Tensor, int]], Tensor]] = None,
-    ):
+        use_buffers: bool = True,
+        **kwargs: Any,
+    ) -> None:
         # The default value is a string so that jsonargparse knows how to serialize it.
         if isinstance(device, str):
             self._device: Optional[Union[torch.device, int]] = torch.device(device)
         else:
             self._device = device
+        self._use_buffers = use_buffers
+        self._kwargs = kwargs
 
-        self._avg_fn = avg_fn
         self._average_model: Optional[AveragedModel] = None
 
         # Number of optimizer steps taken, when the average model was last updated. Initializing this with zero ensures
@@ -76,8 +106,9 @@ def should_update(self, step_idx: Optional[int] = None, epoch_idx: Optional[int]
         """Called after every optimizer step and after every training epoch to check whether the average model should
         be updated.
 
-        One of the arguments is set to the zero-based index of the last training step or epoch. The user can customize
-        when the average model gets updated by overriding this method.
+        One of the arguments is set to the zero-based index of the last training step or epoch. The default
+        implementation returns ``True`` when any ``step_idx`` is provided. The user can customize when the average model
+        gets updated by overriding this method.
 
         Args:
             step_idx: Index of the last optimizer step, or ``None`` when called at the epoch end.
@@ -103,7 +134,9 @@ def setup(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", stage: s
         """
         if stage == "fit":
            device = self._device or pl_module.device
-            self._average_model = AveragedModel(model=pl_module, device=device, avg_fn=self._avg_fn, use_buffers=True)
+            self._average_model = AveragedModel(
+                model=pl_module, device=device, use_buffers=self._use_buffers, **self._kwargs
+            )
 
     @override
     def on_train_batch_end(
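
The practical effect of this change is that reaching the ``AveragedModel`` options no longer requires subclassing: any keyword argument is forwarded as-is. A minimal sketch of direct keyword use (``LitModel`` and ``train_loader`` are placeholder names, not part of this commit):

    from lightning.pytorch import Trainer
    from lightning.pytorch.callbacks import WeightAveraging
    from torch.optim.swa_utils import get_ema_avg_fn

    # avg_fn is forwarded through **kwargs to the AveragedModel constructor,
    # so an exponential moving average no longer needs a subclass.
    ema = WeightAveraging(avg_fn=get_ema_avg_fn(decay=0.999))

    trainer = Trainer(callbacks=ema, max_epochs=10)
    trainer.fit(LitModel(), train_loader)  # placeholder module and dataloader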

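The new ``use_buffers`` flag pairs with ``torch.optim.swa_utils.update_bn()`` as the docstring describes. A sketch under the same placeholder names, relying on the documented behavior that the averaged values replace the model parameters when training finishes:

    from torch.optim.swa_utils import update_bn

    # Average only the parameters; batch normalization statistics are left to the user.
    trainer = Trainer(callbacks=WeightAveraging(use_buffers=False), max_epochs=10)
    model = LitModel()
    trainer.fit(model, train_loader)

    # After fit() the model holds the averaged parameters, so recompute the
    # batch normalization statistics with one pass over the training data.
    update_bn(train_loader, model)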