
Commit c63e9f7

implement special case max_epoch==-1

1 parent 79ffe50

File tree: 1 file changed (+14, −7 lines)


src/lightning/pytorch/callbacks/stochastic_weight_avg.py

Lines changed: 14 additions & 7 deletions
@@ -139,6 +139,8 @@ def swa_start(self) -> int:
 
     @property
     def swa_end(self) -> int:
+        if self._max_epochs == -1:
+            return float("inf")  # type: ignore[return-value]
         return self._max_epochs - 1  # 0-based
 
     @staticmethod
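
Returning float("inf") keeps the callback's existing range checks (of the form swa_start <= current_epoch <= swa_end) true for every epoch once training is unbounded. A minimal sketch of the comparison semantics this relies on (plain Python, not part of the patch; the values are illustrative):

# swa_start stays finite; swa_end becomes unbounded when max_epochs == -1.
swa_start, swa_end = 10, float("inf")

for current_epoch in (10, 100, 10**6):
    # The range check that gates SWA updates now holds for every epoch:
    assert swa_start <= current_epoch <= swa_end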
@@ -163,12 +165,17 @@ def on_fit_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
 
         assert trainer.max_epochs is not None
         if isinstance(self._swa_epoch_start, float):
+            if trainer.max_epochs == -1:
+                raise MisconfigurationException(
+                    "SWA with `swa_epoch_start` as a float is not supported when `max_epochs=-1`. "
+                    "Please provide `swa_epoch_start` as an integer."
+                )
             self._swa_epoch_start = int(trainer.max_epochs * self._swa_epoch_start)
 
         self._model_contains_batch_norm = self.pl_module_contains_batch_norm(pl_module)
 
         self._max_epochs = trainer.max_epochs
-        if self._model_contains_batch_norm:
+        if self._model_contains_batch_norm and trainer.max_epochs != -1:
             # virtually increase max_epochs to perform batch norm update on latest epoch.
             assert trainer.fit_loop.max_epochs is not None
             trainer.fit_loop.max_epochs += 1
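
A fractional swa_epoch_start means "start SWA at this fraction of max_epochs", which is undefined when training never ends, hence the new exception. A hedged reproduction sketch (the callback and exception come from this patch; the trainer setup is illustrative):

from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import StochasticWeightAveraging

# swa_epoch_start=0.8 would mean "start at 80% of max_epochs",
# which cannot be resolved when max_epochs=-1:
trainer = Trainer(
    max_epochs=-1,
    callbacks=[StochasticWeightAveraging(swa_lrs=1e-2, swa_epoch_start=0.8)],
)
# trainer.fit(model)  # would raise MisconfigurationException per the hunk above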
@@ -243,7 +250,7 @@ def on_train_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
             self._latest_update_epoch = trainer.current_epoch
 
         # Note: No > here in case the callback is saved with the model and training continues
-        if trainer.current_epoch == self.swa_end + 1:
+        if self._max_epochs != -1 and trainer.current_epoch == self.swa_end + 1:
             # Transfer weights from average model to pl_module
             assert self._average_model is not None
             self.transfer_weights(self._average_model, pl_module)
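
With swa_end equal to float("inf"), the equality alone could never fire (inf + 1 is still inf, and no int compares equal to it), so the added guard mainly documents intent and short-circuits the comparison. A quick illustration of that invariant (not part of the patch):

swa_end = float("inf")

# IEEE-754 arithmetic: inf + 1 == inf, and integers never equal inf,
# so the weight transfer above can only trigger for finite max_epochs.
assert swa_end + 1 == float("inf")
assert all(epoch != swa_end + 1 for epoch in range(1_000))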
@@ -267,17 +274,17 @@ def on_train_epoch_end(self, trainer: "pl.Trainer", *args: Any) -> None:
     @override
     def on_train_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
         # the trainer increases the current epoch before this hook is called
-        if self._model_contains_batch_norm and trainer.current_epoch - 1 == self.swa_end + 1:
+        if self._model_contains_batch_norm and self._max_epochs != -1 and trainer.current_epoch - 1 == self.swa_end + 1:
             # BatchNorm epoch update. Reset state
             trainer.accumulate_grad_batches = self._accumulate_grad_batches
             trainer.fit_loop.max_batches -= 1
             assert trainer.fit_loop.max_epochs is not None
             trainer.fit_loop.max_epochs -= 1
             self.reset_momenta()
-        elif trainer.current_epoch - 1 == self.swa_end:
-            # Last SWA epoch. Transfer weights from average model to pl_module
-            assert self._average_model is not None
-            self.transfer_weights(self._average_model, pl_module)
+        elif trainer.current_epoch - 1 == self.swa_end or self._max_epochs == -1:
+            # Last SWA epoch or infinite training. Transfer weights from average model to pl_module
+            if self._average_model is not None:
+                self.transfer_weights(self._average_model, pl_module)
 
     @staticmethod
     def transfer_weights(src_pl_module: "pl.LightningModule", dst_pl_module: "pl.LightningModule") -> None:
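
Note that the elif branch swaps the assert for an if-check: with max_epochs=-1, on_train_end is typically reached by interrupting fit(), possibly before any averaging has happened, so _average_model may legitimately be None. A hedged end-to-end sketch of the configuration this commit enables (model and datamodule are user-provided placeholders):

from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import StochasticWeightAveraging

# With max_epochs=-1, swa_epoch_start must be an absolute epoch index (int):
swa = StochasticWeightAveraging(swa_lrs=1e-2, swa_epoch_start=10)
trainer = Trainer(max_epochs=-1, callbacks=[swa])  # train until interrupted

# On interruption (e.g. KeyboardInterrupt), on_train_end copies the averaged
# weights into the LightningModule if any averaging has occurred.
# trainer.fit(model, datamodule=dm)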
