diff --git a/src/lightning/pytorch/CHANGELOG.md b/src/lightning/pytorch/CHANGELOG.md
index 3558fe3cadc66..4b25fd8061837 100644
--- a/src/lightning/pytorch/CHANGELOG.md
+++ b/src/lightning/pytorch/CHANGELOG.md
@@ -20,6 +20,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 -
 
+### Fixed
+
+- Fix `StochasticWeightAveraging` with infinite epochs ([#21396](https://github.com/Lightning-AI/pytorch-lightning/pull/21396))
+
 
 
 ## [2.6.0] - 2025-11-28
diff --git a/src/lightning/pytorch/callbacks/stochastic_weight_avg.py b/src/lightning/pytorch/callbacks/stochastic_weight_avg.py
index 79c5423c54084..9b051b96d0027 100644
--- a/src/lightning/pytorch/callbacks/stochastic_weight_avg.py
+++ b/src/lightning/pytorch/callbacks/stochastic_weight_avg.py
@@ -139,6 +139,8 @@ def swa_start(self) -> int:
 
     @property
     def swa_end(self) -> int:
+        if self._max_epochs == -1:
+            return float("inf")  # type: ignore[return-value]
         return self._max_epochs - 1  # 0-based
 
     @staticmethod
@@ -163,12 +165,17 @@ def on_fit_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -
 
         assert trainer.max_epochs is not None
         if isinstance(self._swa_epoch_start, float):
+            if trainer.max_epochs == -1:
+                raise MisconfigurationException(
+                    "SWA with `swa_epoch_start` as a float is not supported when `max_epochs=-1`. "
+                    "Please provide `swa_epoch_start` as an integer."
+                )
             self._swa_epoch_start = int(trainer.max_epochs * self._swa_epoch_start)
 
         self._model_contains_batch_norm = self.pl_module_contains_batch_norm(pl_module)
 
         self._max_epochs = trainer.max_epochs
-        if self._model_contains_batch_norm:
+        if self._model_contains_batch_norm and trainer.max_epochs != -1:
             # virtually increase max_epochs to perform batch norm update on latest epoch.
             assert trainer.fit_loop.max_epochs is not None
             trainer.fit_loop.max_epochs += 1
@@ -243,7 +250,7 @@ def on_train_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningMo
             self._latest_update_epoch = trainer.current_epoch
 
         # Note: No > here in case the callback is saved with the model and training continues
-        if trainer.current_epoch == self.swa_end + 1:
+        if self._max_epochs != -1 and trainer.current_epoch == self.swa_end + 1:
             # Transfer weights from average model to pl_module
             assert self._average_model is not None
             self.transfer_weights(self._average_model, pl_module)
@@ -267,17 +274,17 @@ def on_train_epoch_end(self, trainer: "pl.Trainer", *args: Any) -> None:
     @override
     def on_train_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
         # the trainer increases the current epoch before this hook is called
-        if self._model_contains_batch_norm and trainer.current_epoch - 1 == self.swa_end + 1:
+        if self._model_contains_batch_norm and self._max_epochs != -1 and trainer.current_epoch - 1 == self.swa_end + 1:
             # BatchNorm epoch update. Reset state
             trainer.accumulate_grad_batches = self._accumulate_grad_batches
             trainer.fit_loop.max_batches -= 1
             assert trainer.fit_loop.max_epochs is not None
             trainer.fit_loop.max_epochs -= 1
             self.reset_momenta()
-        elif trainer.current_epoch - 1 == self.swa_end:
-            # Last SWA epoch. Transfer weights from average model to pl_module
-            assert self._average_model is not None
-            self.transfer_weights(self._average_model, pl_module)
+        elif trainer.current_epoch - 1 == self.swa_end or self._max_epochs == -1:
+            # Last SWA epoch or infinite training. Transfer weights from average model to pl_module
+            if self._average_model is not None:
+                self.transfer_weights(self._average_model, pl_module)
 
     @staticmethod
     def transfer_weights(src_pl_module: "pl.LightningModule", dst_pl_module: "pl.LightningModule") -> None:
diff --git a/tests/tests_pytorch/callbacks/test_stochastic_weight_avg.py b/tests/tests_pytorch/callbacks/test_stochastic_weight_avg.py
index c63dd4e5c2ac9..242679b7d559c 100644
--- a/tests/tests_pytorch/callbacks/test_stochastic_weight_avg.py
+++ b/tests/tests_pytorch/callbacks/test_stochastic_weight_avg.py
@@ -387,5 +387,35 @@ def test_misconfiguration_error_with_sharded_model(tmp_path, strategy: str):
         trainer.fit(model)
 
 
+def test_swa_with_infinite_epochs_and_batchnorm(tmp_path):
+    """Test that SWA works correctly with max_epochs=-1 (infinite training) and BatchNorm."""
+    model = SwaTestModel(batchnorm=True)
+    swa_callback = StochasticWeightAveraging(swa_lrs=0.1, swa_epoch_start=2)
+
+    trainer = Trainer(
+        default_root_dir=tmp_path,
+        enable_progress_bar=False,
+        enable_model_summary=False,
+        max_epochs=-1,
+        max_steps=30,  # Use max_steps as stopping condition
+        limit_train_batches=5,
+        limit_val_batches=0,
+        callbacks=[swa_callback],
+        logger=False,
+    )
+    assert trainer.max_epochs == -1
+    assert trainer.fit_loop.max_epochs == -1
+
+    trainer.fit(model)
+    assert trainer.current_epoch >= 5
+    assert trainer.global_step == 30
+    assert trainer.max_epochs == -1
+
+    # Verify SWA was actually applied (update_parameters should have been called)
+    # SWA starts at epoch 2, so with 6 epochs (0-5), we should have 4 updates (epochs 2, 3, 4, 5)
+    assert swa_callback.n_averaged is not None
+    assert swa_callback.n_averaged > 0, "SWA should have updated parameters"
+
+
 def _backward_patch(trainer: Trainer) -> AbstractContextManager:
     return mock.patch.object(Strategy, "backward", wraps=trainer.strategy.backward)