 from lightning.pytorch.utilities.types import STEP_OUTPUT


+def _return_true(x: int) -> bool:
+    return True
+
+
+def _return_false(x: int) -> bool:
+    return False
+
+
 class WeightAveraging(Callback):
     r"""A callback that updates an averaged model for Stochastic Weight Averaging (SWA) or Exponential Moving Average
     (EMA) after each training step.

+    The user can provide either ``update_on_step`` or ``update_on_epoch``, a function that determines when the average
+    model should be updated. If neither function is provided, the average model will be updated after every optimizer
+    step.
+
     During validation and after the training finishes, the current model parameters will be replaced with the averaged
     values.

@@ -43,22 +55,39 @@ class WeightAveraging(Callback):
         avg_fn: The averaging function used to update the parameters. The function must take in an
             :class:`AveragedModel` parameter, a current model parameter, and the number of models already averaged. If
             ``None``, an equally weighted average will be used.
+        update_on_step: A function that takes the number of optimizer steps taken, and returns ``True`` if the average
+            model should be updated.
+        update_on_epoch: A function that takes the zero-based epoch number, and returns ``True`` if the average model
+            should be updated.

     """

     def __init__(
         self,
-        device: torch.device | str | None = torch.device("cpu"),
+        device: torch.device | int | None = torch.device("cpu"),
         avg_fn: Callable[[Tensor, Tensor, Tensor | int], Tensor] | None = None,
+        update_on_step: Callable[[int], bool] | None = None,
+        update_on_epoch: Callable[[int], bool] | None = None,
     ):
         self._device = device
         self._avg_fn = avg_fn
+
+        if (update_on_step is None) and (update_on_epoch is None):
+            self._update_on_step: Callable[[int], bool] = _return_true
+            self._update_on_epoch: Callable[[int], bool] = _return_false
+        else:
+            self._update_on_step = _return_false if update_on_step is None else update_on_step
+            self._update_on_epoch = _return_false if update_on_epoch is None else update_on_epoch
+
         self._average_model: AveragedModel | None = None

         # Number of optimizer steps taken, when the average model was last updated. Initializing this with zero ensures
         # that the average model will be first updated after the first optimizer step, which takes place after N batches
         # when using accumulate_grad_batches=N.
         self._latest_update_step = 0
+        # The epoch after which the average model was last updated. The first epoch is 0, so initializing this to a
+        # negative value means that if update_on_epoch(0) returns True, the first update is after the first epoch.
+        self._latest_update_epoch = -1

     def setup(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", stage: str) -> None:
         """Called when fit, validate, test, predict, or tune begins.
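The ``update_on_step`` and ``update_on_epoch`` arguments added in the hunk above take plain callables, so the averaging schedule is set at construction time. A minimal usage sketch, not part of this diff; it assumes ``WeightAveraging`` is importable alongside the other Lightning callbacks, and the epoch threshold is illustrative only:

    from lightning.pytorch import Trainer

    # Default: with neither predicate given, the averaged model is updated after every optimizer step.
    ema_style = WeightAveraging()

    # SWA-style schedule (illustrative threshold): update only at the end of epoch 10 and later.
    swa_style = WeightAveraging(update_on_epoch=lambda epoch: epoch >= 10)

    trainer = Trainer(max_epochs=20, callbacks=[swa_style])
    # trainer.fit(model, datamodule=datamodule)  # model and datamodule are assumed to exist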
@@ -80,7 +109,7 @@ def on_train_batch_end(
     ) -> None:
         """Called when a training batch ends.

-        Updates the :class:`AveragedModel` parameters.
+        Updates the :class:`AveragedModel` parameters, if requested by ``update_on_step()``.

         Args:
             trainer: The current :class:`~lightning.pytorch.trainer.trainer.Trainer` instance.
@@ -90,11 +119,26 @@ def on_train_batch_end(
             batch_idx: Index of the training batch.

         """
-        if trainer.global_step > self._latest_update_step:
+        if self._update_on_step(trainer.global_step) and (trainer.global_step > self._latest_update_step):
             assert self._average_model is not None
             self._average_model.update_parameters(pl_module)
             self._latest_update_step = trainer.global_step

+    def on_train_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
+        """Called when a training epoch ends.
+
+        Updates the :class:`AveragedModel` parameters, if requested by ``update_on_epoch()``.
+
+        Args:
+            trainer: The current :class:`~lightning.pytorch.trainer.trainer.Trainer` instance.
+            pl_module: The current :class:`~lightning.pytorch.core.LightningModule` instance.
+
+        """
+        if self._update_on_epoch(trainer.current_epoch) and (trainer.current_epoch > self._latest_update_epoch):
+            assert self._average_model is not None
+            self._average_model.update_parameters(pl_module)
+            self._latest_update_epoch = trainer.current_epoch
+
     def on_train_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
         """Called when training ends.

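In the hunk above, ``on_train_batch_end`` consults ``update_on_step()`` with ``trainer.global_step`` and the new ``on_train_epoch_end`` consults ``update_on_epoch()`` with the zero-based ``trainer.current_epoch``; the ``_latest_update_step`` and ``_latest_update_epoch`` guards prevent the same state from being averaged twice. A sketch of predicates that would plug into these hooks; the step interval and warmup length are made-up example values:

    def every_100_steps(step: int) -> bool:
        # Receives trainer.global_step from on_train_batch_end.
        return step > 0 and step % 100 == 0

    def after_warmup(epoch: int) -> bool:
        # Receives the zero-based trainer.current_epoch from on_train_epoch_end.
        return epoch >= 5

    callback = WeightAveraging(update_on_step=every_100_steps, update_on_epoch=after_warmup)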
@@ -173,6 +217,7 @@ def on_save_checkpoint(
             checkpoint: The checkpoint dictionary that will be saved.

         """
+        assert self._average_model is not None
         rank_zero_info("The average model parameters will be saved to the state_dict in the checkpoint.")
         average_model_state = self._average_model.state_dict()
         checkpoint["current_model_state"] = checkpoint["state_dict"]
@@ -196,6 +241,7 @@ def on_load_checkpoint(
             checkpoint: The full checkpoint dictionary that got loaded by the Trainer.

         """
+        assert self._average_model is not None
         if ("current_model_state" in checkpoint) and ("averaging_state" in checkpoint):
             rank_zero_info("Found current_model_state in the checkpoint. This will be used to initialize the model.")
             average_model_state = {"module." + name: value for name, value in checkpoint["state_dict"].items()}
@@ -216,6 +262,7 @@ def _swap_models(self, pl_module: "pl.LightningModule") -> None:
             pl_module: The current :class:`~lightning.pytorch.core.LightningModule` instance.

         """
+        assert self._average_model is not None
         average_params = itertools.chain(self._average_model.module.parameters(), self._average_model.module.buffers())
         current_params = itertools.chain(pl_module.parameters(), pl_module.buffers())
         for average_param, current_param in zip(average_params, current_params):
@@ -230,6 +277,7 @@ def _copy_average_to_current(self, pl_module: "pl.LightningModule") -> None:
             pl_module: The current :class:`~lightning.pytorch.core.LightningModule` instance.

         """
+        assert self._average_model is not None
         average_params = itertools.chain(self._average_model.module.parameters(), self._average_model.module.buffers())
         current_params = itertools.chain(pl_module.parameters(), pl_module.buffers())
         for average_param, current_param in zip(average_params, current_params):
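For the ``avg_fn`` argument described in the class docstring, a minimal sketch of an exponential-moving-average update with the expected three-argument signature; the decay constant is an arbitrary example, and recent PyTorch releases provide an equivalent helper, ``torch.optim.swa_utils.get_ema_avg_fn``:

    import torch
    from torch import Tensor

    def ema_avg_fn(averaged_param: Tensor, current_param: Tensor, num_averaged: Tensor | int) -> Tensor:
        # Fixed-decay EMA; the running count of already averaged models is not needed here.
        decay = 0.999
        return decay * averaged_param + (1.0 - decay) * current_param

    callback = WeightAveraging(device=torch.device("cpu"), avg_fn=ema_avg_fn)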