@@ -139,6 +139,8 @@ def swa_start(self) -> int:
 
     @property
     def swa_end(self) -> int:
+        if self._max_epochs == -1:
+            return float("inf")  # type: ignore[return-value]
         return self._max_epochs - 1  # 0-based
 
     @staticmethod
@@ -163,12 +165,17 @@ def on_fit_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
 
         assert trainer.max_epochs is not None
         if isinstance(self._swa_epoch_start, float):
+            if trainer.max_epochs == -1:
+                raise MisconfigurationException(
+                    "SWA with `swa_epoch_start` as a float is not supported when `max_epochs=-1`. "
+                    "Please provide `swa_epoch_start` as an integer."
+                )
             self._swa_epoch_start = int(trainer.max_epochs * self._swa_epoch_start)
 
         self._model_contains_batch_norm = self.pl_module_contains_batch_norm(pl_module)
 
         self._max_epochs = trainer.max_epochs
-        if self._model_contains_batch_norm:
+        if self._model_contains_batch_norm and trainer.max_epochs != -1:
             # virtually increase max_epochs to perform batch norm update on latest epoch.
             assert trainer.fit_loop.max_epochs is not None
             trainer.fit_loop.max_epochs += 1
@@ -243,7 +250,7 @@ def on_train_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
             self._latest_update_epoch = trainer.current_epoch
 
         # Note: No > here in case the callback is saved with the model and training continues
-        if trainer.current_epoch == self.swa_end + 1:
+        if self._max_epochs != -1 and trainer.current_epoch == self.swa_end + 1:
             # Transfer weights from average model to pl_module
             assert self._average_model is not None
             self.transfer_weights(self._average_model, pl_module)
@@ -267,17 +274,17 @@ def on_train_epoch_end(self, trainer: "pl.Trainer", *args: Any) -> None:
     @override
     def on_train_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
         # the trainer increases the current epoch before this hook is called
-        if self._model_contains_batch_norm and trainer.current_epoch - 1 == self.swa_end + 1:
+        if self._model_contains_batch_norm and self._max_epochs != -1 and trainer.current_epoch - 1 == self.swa_end + 1:
             # BatchNorm epoch update. Reset state
             trainer.accumulate_grad_batches = self._accumulate_grad_batches
             trainer.fit_loop.max_batches -= 1
             assert trainer.fit_loop.max_epochs is not None
             trainer.fit_loop.max_epochs -= 1
             self.reset_momenta()
-        elif trainer.current_epoch - 1 == self.swa_end:
-            # Last SWA epoch. Transfer weights from average model to pl_module
-            assert self._average_model is not None
-            self.transfer_weights(self._average_model, pl_module)
+        elif trainer.current_epoch - 1 == self.swa_end or self._max_epochs == -1:
+            # Last SWA epoch or infinite training. Transfer weights from average model to pl_module
+            if self._average_model is not None:
+                self.transfer_weights(self._average_model, pl_module)
 
     @staticmethod
     def transfer_weights(src_pl_module: "pl.LightningModule", dst_pl_module: "pl.LightningModule") -> None:
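
For context, a minimal sketch of how the changed behavior could be exercised; this is not part of the diff, and the model/dataloader names are placeholders, but it assumes the public StochasticWeightAveraging and Trainer APIs:

# Sketch: SWA with "infinite" training (max_epochs=-1), bounded by max_steps instead.
# swa_epoch_start must be an int here; a float fraction cannot be resolved
# without a finite max_epochs, which is exactly what the new check enforces.
import pytorch_lightning as pl
from pytorch_lightning.callbacks import StochasticWeightAveraging

swa = StochasticWeightAveraging(swa_lrs=1e-2, swa_epoch_start=5)  # int, not float
trainer = pl.Trainer(max_epochs=-1, max_steps=1_000, callbacks=[swa])
# trainer.fit(model, train_dataloaders=train_loader)  # model/loader assumed defined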