3030from lightning .pytorch .utilities .types import STEP_OUTPUT
3131
3232
33- def _return_true (x : int ) -> bool :
34- return True
35-
36-
37- def _return_false (x : int ) -> bool :
38- return False
39-
40-
4133class WeightAveraging (Callback ):
4234 r"""A callback that updates an averaged model for Stochastic Weight Averaging (SWA) or Exponential Moving Average
4335 (EMA) after each training step.
4436
45- The user should provide either `update_on_step` or `update_on_epoch`, a function that determines when the average
46- model should be updated. If neither function is provided, the average model will be updated after every optimizer
47- step.
37+ The user can customize when the average model is updated by overriding the ``should_update()`` method.
4838
4939 During validation and after the training finishes, the current model parameters will be replaced with the averaged
5040 values.
@@ -55,40 +45,44 @@ class WeightAveraging(Callback):
5545 avg_fn: The averaging function used to update the parameters. The function must take in an
5646 :class:`AveragedModel` parameter, a current model parameter, and the number of models already averaged. If
5747 ``None``, an equally weighted average will be used.
58- update_on_step: A function that takes the number of optimizer steps taken, and returns ``True`` if the average
59- model should be updated.
60- update_on_epoch: A function that takes the zero-based epoch number, and returns ``True`` if the average model
61- should be updated.
6248
6349 """
6450
6551 def __init__ (
6652 self ,
6753 device : Optional [Union [torch .device , int ]] = torch .device ("cpu" ),
6854 avg_fn : Optional [Callable [[Tensor , Tensor , Union [Tensor , int ]], Tensor ]] = None ,
69- update_on_step : Optional [Callable [[int ], bool ]] = None ,
70- update_on_epoch : Optional [Callable [[int ], bool ]] = None ,
7155 ):
7256 self ._device = device
7357 self ._avg_fn = avg_fn
74-
75- if (update_on_step is None ) and (update_on_epoch is None ):
76- self ._update_on_step : Callable [[int ], bool ] = _return_true
77- self ._update_on_epoch : Callable [[int ], bool ] = _return_false
78- else :
79- self ._update_on_step = _return_false if update_on_step is None else update_on_step
80- self ._update_on_epoch = _return_false if update_on_epoch is None else update_on_epoch
81-
8258 self ._average_model : Optional [AveragedModel ] = None
8359
8460 # Number of optimizer steps taken, when the average model was last updated. Initializing this with zero ensures
85- # that the average model will be first updated after the first optimizer step, which takes place after N batches
86- # when using accumulate_grad_batches=N.
61+ # that self.should_update() will be first called after the first optimizer step, which takes place after N
62+ # batches when using accumulate_grad_batches=N.
8763 self ._latest_update_step = 0
8864 # The epoch after which the average model was last updated. The first epoch is 0, so initializing this to a
89- # negative value means that if update_on_step(0) returns True, the first update is after the first epoch.
65+ # negative value means that if self.should_update(epoch_idx=0) returns True, the first update is after the first
66+ # epoch.
9067 self ._latest_update_epoch = - 1
9168
69+ def should_update (self , step_idx : Optional [int ] = None , epoch_idx : Optional [int ] = None ) -> bool :
70+ """Called after every optimizer step and after every training epoch to check whether the average model should
71+ be updated.
72+
73+ Exactly one of the arguments is set, to the zero-based index of the last optimizer step or epoch. The default
74+ implementation returns ``True`` only after optimizer steps; override this method to customize the schedule.
75+
76+ Args:
77+ step_idx: Index of the last optimizer step, or ``None`` when called at the epoch end.
78+ epoch_idx: Index of the last epoch, or ``None`` when called after an optimizer step.
79+
80+ Returns:
81+ ``True`` if the average model should be updated and ``False`` if not.
82+
83+ """
84+ return step_idx is not None
85+
9286 def setup (self , trainer : "pl.Trainer" , pl_module : "pl.LightningModule" , stage : str ) -> None :
9387 """Called when fit, validate, test, predict, or tune begins.
9488
@@ -109,7 +103,7 @@ def on_train_batch_end(
109103 ) -> None :
110104 """Called when a training batch ends.
111105
112- Updates the :class:`AveragedModel` parameters, if requested by ``update_on_step ()``.
106+ Updates the :class:`AveragedModel` parameters, if requested by ``self.should_update()``.
113107
114108 Args:
115109 trainer: The current :class:`~lightning.pytorch.trainer.trainer.Trainer` instance.
@@ -119,22 +113,25 @@ def on_train_batch_end(
119113 batch_idx: Index of the training batch.
120114
121115 """
122- if self ._update_on_step (trainer .global_step ) and (trainer .global_step > self ._latest_update_step ):
116+ # trainer.global_step is the number of optimizer steps taken so far, i.e. 1 after the first optimizer step. To
117+ # make step_idx consistent with epoch_idx, we'll pass a zero-based index.
118+ step_idx = trainer .global_step - 1
119+ if (trainer .global_step > self ._latest_update_step ) and self .should_update (step_idx = step_idx ):
123120 assert self ._average_model is not None
124121 self ._average_model .update_parameters (pl_module )
125122 self ._latest_update_step = trainer .global_step
126123
127124 def on_train_epoch_end (self , trainer : "pl.Trainer" , pl_module : "pl.LightningModule" ) -> None :
128125 """Called when a training epoch ends.
129126
130- Updates the :class:`AveragedModel` parameters, if requested by ``update_on_epoch ()``.
127+ Updates the :class:`AveragedModel` parameters, if requested by ``self.should_update()``.
131128
132129 Args:
133130 trainer: The current :class:`~lightning.pytorch.trainer.trainer.Trainer` instance.
134131 pl_module: The current :class:`~lightning.pytorch.core.LightningModule` instance.
135132
136133 """
137- if self . _update_on_epoch (trainer .current_epoch ) and ( trainer .current_epoch > self . _latest_update_epoch ):
134+ if (trainer .current_epoch > self . _latest_update_epoch ) and self . should_update ( epoch_idx = trainer .current_epoch ):
138135 assert self ._average_model is not None
139136 self ._average_model .update_parameters (pl_module )
140137 self ._latest_update_epoch = trainer .current_epoch
@@ -218,17 +215,21 @@ def on_save_checkpoint(
218215
219216 """
220217 if self ._average_model is None :
221- raise Exception ("Trying to save a checkpoint, but no average model (outside fit). Don't know what to do." )
222-
223- rank_zero_info ("The average model parameters will be saved to the state_dict in the checkpoint." )
224- average_model_state = self ._average_model .state_dict ()
225- checkpoint ["current_model_state" ] = checkpoint ["state_dict" ]
226- checkpoint ["state_dict" ] = {
227- name [7 :]: value for name , value in average_model_state .items () if name .startswith ("module." )
228- }
229- checkpoint ["averaging_state" ] = {
230- name : value for name , value in average_model_state .items () if not name .startswith ("module." )
231- }
218+ rank_zero_info (
219+ "You're using the WeightAveraging callback, but saving a checkpoint outside the 'fit' stage. The state "
220+ "of the WeightAveraging callback won't be saved in the checkpoint. If training has finished, the "
221+ "average model parameters will be saved to the state_dict in the checkpoint."
222+ )
223+ else :
224+ rank_zero_info ("The average model parameters will be saved to the state_dict in the checkpoint." )
225+ average_model_state = self ._average_model .state_dict ()
226+ checkpoint ["current_model_state" ] = checkpoint ["state_dict" ]
227+ checkpoint ["state_dict" ] = {
228+ name [7 :]: value for name , value in average_model_state .items () if name .startswith ("module." )
229+ }
230+ checkpoint ["averaging_state" ] = {
231+ name : value for name , value in average_model_state .items () if not name .startswith ("module." )
232+ }
232233
233234 def on_load_checkpoint (
234235 self , trainer : "pl.Trainer" , pl_module : "pl.LightningModule" , checkpoint : dict [str , Any ]
@@ -244,9 +245,12 @@ def on_load_checkpoint(
244245
245246 """
246247 if self ._average_model is None :
247- raise Exception ("Trying to load a checkpoint, but no average model (outside fit). Don't know what to do." )
248-
249- if ("current_model_state" in checkpoint ) and ("averaging_state" in checkpoint ):
248+ rank_zero_warn (
249+ "You're using the WeightAveraging callback, but loading a checkpoint outside the 'fit' stage. The "
250+ "WeightAveraging state cannot be restored. If you're using the checkpoint for prediction or testing, "
251+ "you can ignore this warning. To disable the warning, remove the WeightAveraging callback."
252+ )
253+ elif ("current_model_state" in checkpoint ) and ("averaging_state" in checkpoint ):
250254 rank_zero_info ("Found current_model_state in the checkpoint. This will be used to initialize the model." )
251255 average_model_state = {"module." + name : value for name , value in checkpoint ["state_dict" ].items ()}
252256 average_model_state |= checkpoint ["averaging_state" ]
0 commit comments