
Commit 410fe14

Author: Seppo Enarvi

The default device for the averaged model is the device of the original model

1 parent 3fd3c22 commit 410fe14

File tree

1 file changed: +9 -9 lines changed


src/lightning/pytorch/callbacks/weight_averaging.py

Lines changed: 9 additions & 9 deletions
@@ -35,12 +35,11 @@ class WeightAveraging(Callback):
     r"""A callback that updates an averaged model for Stochastic Weight Averaging (SWA) or Exponential Moving Average
     (EMA) after each training step.
 
-    Arguments given to the constructor will be passed to the :class:`AveragedModel` constructor. There are a couple of
-    differences to the default values, however. By default, the average model is stored on the CPU. If ``device`` is set
-    to ``None``, the device will be inferred from the original model. By default, the callback will compute running
-    averages for both the parameters and the buffers of the model. Setting ``use_buffers`` to ``False`` will cause only
-    the model parameters to be averaged, leaving updating the batch normalization statistics to the user (using
-    ``torch.optim.swa_utils.update_bn()``).
+    Arguments given to the constructor will be passed to the :class:`AveragedModel` constructor. If no ``device`` is
+    specified, the device of the original model will be used. Contrary to :class:`AveragedModel`, ``use_buffers`` is set
+    to ``True`` by default. That is, by default the callback will compute running averages for both the parameters and
+    the buffers of the model. Setting ``use_buffers`` to ``False`` will cause only the model parameters to be averaged,
+    leaving updating the batch normalization statistics to the user (using ``torch.optim.swa_utils.update_bn()``).
 
     You can provide a custom averaging function with the ``avg_fn`` or ``multi_avg_fn`` parameter. See the
     :class:`AveragedModel` class for details. If no averaging function is provided, the default is to compute the
@@ -79,8 +78,9 @@ def should_update(self, step_idx=None, epoch_idx=None):
         trainer.fit(model, dataloader)
 
     Args:
-        device: If provided, the :class:`AveragedModel` will be stored on the ``device``. If ``None`` the device will be
-            inferred from the original model.
+        device: By default, the :class:`AveragedModel` will be stored on the same device as the original model. If the
+            ``device`` argument is provided, the :class:`AveragedModel` will be stored on this device instead. If you
+            run out of GPU memory, you might want to use ``"cpu"``.
         use_buffers: If ``False``, the buffers of the model will not be averaged.
         kwargs: Additional keyword arguments to be passed to the :class:`AveragedModel` constructor, such as ``avg_fn``
             or ``multi_avg_fn``.
@@ -89,7 +89,7 @@ def should_update(self, step_idx=None, epoch_idx=None):
 
     def __init__(
         self,
-        device: Optional[Union[torch.device, str, int]] = "cpu",
+        device: Optional[Union[torch.device, str, int]] = None,
         use_buffers: bool = True,
        **kwargs: Any,
     ) -> None:
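The behavior change in this commit can be sketched in plain Python. The helper below, `resolve_device`, is purely illustrative (it is not part of the Lightning API): it shows how the averaged model's device is now chosen, falling back to the original model's device when the ``device`` argument is ``None``.

```python
def resolve_device(model_device, device=None):
    """Illustrative sketch of the new default-device logic (not Lightning API).

    Before this commit the default was "cpu"; after it, the default
    (device=None) falls back to the device of the original model.
    """
    # An explicit device always wins; otherwise follow the original model.
    return device if device is not None else model_device


# By default, the averaged model follows the original model's device...
assert resolve_device("cuda:0") == "cuda:0"
# ...but an explicit device (e.g. "cpu" to save GPU memory) still overrides it.
assert resolve_device("cuda:0", device="cpu") == "cpu"
```

Passing ``device="cpu"`` explicitly restores the pre-commit behavior, which remains useful when GPU memory is tight.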
