Skip to content

Commit 7c7a4ba

Browse files
adamreeve, carmocca, and awaelchli
authored
Fix SWA LR scheduler not being stepped (#12446)
Co-authored-by: Carlos Mocholí <[email protected]> Co-authored-by: Adrian Wälchli <[email protected]>
1 parent 648cc2d commit 7c7a4ba

File tree

2 files changed

+4
-1
lines changed

2 files changed

+4
-1
lines changed

pytorch_lightning/callbacks/stochastic_weight_avg.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -179,7 +179,8 @@ def on_train_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningMo
179179
anneal_strategy=self._annealing_strategy,
180180
last_epoch=trainer.max_epochs if self._annealing_strategy == "cos" else -1,
181181
)
182-
default_scheduler_cfg = LRSchedulerConfig(self._swa_scheduler)
182+
# We assert that there is only one optimizer on fit start, so we know opt_idx is always 0
183+
default_scheduler_cfg = LRSchedulerConfig(self._swa_scheduler, opt_idx=0)
183184
assert default_scheduler_cfg.interval == "epoch" and default_scheduler_cfg.frequency == 1
184185

185186
if trainer.lr_scheduler_configs:

tests/callbacks/test_stochastic_weight_avg.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,8 @@ def on_train_epoch_end(self, trainer, *args):
8787
if self.swa_start <= trainer.current_epoch <= self.swa_end:
8888
swa_epoch = trainer.current_epoch - self.swa_start
8989
assert self.n_averaged == swa_epoch + 1
90+
# Scheduler is stepped once on initialization and then at the end of each epoch
91+
assert self._swa_scheduler._step_count == swa_epoch + 2
9092
elif trainer.current_epoch > self.swa_end:
9193
assert self.n_averaged == self._max_epochs - self.swa_start
9294

0 commit comments

Comments (0)