@@ -35,12 +35,11 @@ class WeightAveraging(Callback):
3535 r"""A callback that updates an averaged model for Stochastic Weight Averaging (SWA) or Exponential Moving Average
3636 (EMA) after each training step.
3737
38- Arguments given to the constructor will be passed to the :class:`AveragedModel` constructor. There are a couple of
39- differences to the default values, however. By default, the average model is stored on the CPU. If ``device`` is set
40- to ``None``, the device will be inferred from the original model. By default, the callback will compute running
41- averages for both the parameters and the buffers of the model. Setting ``use_buffers`` to ``False`` will cause only
42- the model parameters to be averaged, leaving updating the batch normalization statistics to the user (using
43- ``torch.optim.swa_utils.update_bn()``).
38+ Arguments given to the constructor will be passed to the :class:`AveragedModel` constructor. If no ``device`` is
39+ specified, the device of the original model will be used. Contrary to :class:`AveragedModel`, ``use_buffers`` is set
40+ to ``True`` by default. That is, by default the callback will compute running averages for both the parameters and
41+ the buffers of the model. Setting ``use_buffers`` to ``False`` will cause only the model parameters to be averaged,
42+ leaving updating the batch normalization statistics to the user (using ``torch.optim.swa_utils.update_bn()``).
4443
4544 You can provide a custom averaging function with the ``avg_fn`` or ``multi_avg_fn`` parameter. See the
4645 :class:`AveragedModel` class for details. If no averaging function is provided, the default is to compute the
@@ -79,8 +78,9 @@ def should_update(self, step_idx=None, epoch_idx=None):
7978 trainer.fit(model, dataloader)
8079
8180 Args:
82- device: If provided, the :class:`AveragedModel` will be stored on the ``device``. If ``None`` the device will be
83- inferred from the original model.
81+ device: By default, the :class:`AveragedModel` will be stored on the same device as the original model. If the
82+ ``device`` argument is provided, the :class:`AveragedModel` will be stored on this device instead. If you
83+ run out of GPU memory, you might want to use ``"cpu"``.
8484 use_buffers: If ``False``, the buffers of the model will not be averaged.
8585 kwargs: Additional keyword arguments to be passed to the :class:`AveragedModel` constructor, such as ``avg_fn``
8686 or ``multi_avg_fn``.
@@ -89,7 +89,7 @@ def should_update(self, step_idx=None, epoch_idx=None):
8989
9090 def __init__ (
9191 self ,
92- device : Optional [Union [torch .device , str , int ]] = "cpu" ,
92+ device : Optional [Union [torch .device , str , int ]] = None ,
9393 use_buffers : bool = True ,
9494 ** kwargs : Any ,
9595 ) -> None :
0 commit comments