Commit 7239860

Fix SWA with a list of learning rates (#8747)

carmocca authored and lexierule committed

* Fix swa lrs - needs test
* Add test
* Update CHANGELOG

1 parent fa6720a commit 7239860
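
For context, the fix lets `StochasticWeightAveraging(swa_lrs=...)` take one learning rate per optimizer parameter group. A minimal usage sketch follows (the model, its layer names, and the Trainer settings are illustrative, mirroring the test added in this commit; they are not part of the commit itself):

import torch
import pytorch_lightning as pl
from pytorch_lightning.callbacks import StochasticWeightAveraging


# Illustrative model with two parameter groups (names are hypothetical).
class TwoGroupModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer1 = torch.nn.Linear(32, 32)
        self.layer2 = torch.nn.Linear(32, 2)

    def forward(self, x):
        return self.layer2(self.layer1(x))

    def training_step(self, batch, batch_idx):
        return self(batch).sum()

    def configure_optimizers(self):
        # Two param groups with different base learning rates.
        return torch.optim.Adam(
            [
                {"params": self.layer1.parameters(), "lr": 0.1},
                {"params": self.layer2.parameters(), "lr": 0.2},
            ]
        )


# One SWA learning rate per param group.
swa = StochasticWeightAveraging(swa_lrs=[0.123, 0.321])
trainer = pl.Trainer(callbacks=[swa], max_epochs=5)

With this commit, each entry of `swa_lrs` is applied to the matching param group instead of a single value being reused for every group.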

File tree

3 files changed: +50 -24 lines

  CHANGELOG.md
  pytorch_lightning/callbacks/stochastic_weight_avg.py
  tests/callbacks/test_stochastic_weight_avg.py

CHANGELOG.md

Lines changed: 3 additions & 6 deletions
@@ -7,7 +7,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 ## [1.4.3] - 2021-08-17
 
+
+- Fixed plateau scheduler stepping on incomplete epoch ([#8861](https://github.com/PyTorchLightning/pytorch-lightning/pull/8861))
 - Fixed infinite loop with CycleIterator and multiple loaders ([#8889](https://github.com/PyTorchLightning/pytorch-lightning/pull/8889))
+- Fixed `StochasticWeightAveraging` with a list of learning rates not applying them to each param group ([#8747](https://github.com/PyTorchLightning/pytorch-lightning/issues/8747))
 
 ## [1.4.2] - 2021-08-10
 
@@ -29,12 +32,6 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed `accelerator=ddp` choice for CPU ([#8645](https://github.com/PyTorchLightning/pytorch-lightning/pull/8645))
 
 
-- Fixed plateau scheduler stepping on incomplete epoch ([#8861](https://github.com/PyTorchLightning/pytorch-lightning/pull/8861))
-
-
-- Fixed infinite loop with CycleIterator and multiple loaders ([#8889](https://github.com/PyTorchLightning/pytorch-lightning/pull/8889))
-
-
 ## [1.4.0] - 2021-07-27
 
 ### Added

pytorch_lightning/callbacks/stochastic_weight_avg.py

Lines changed: 9 additions & 16 deletions
@@ -164,25 +164,18 @@ def on_train_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningMo
         # move average model to request device.
         self._average_model = self._average_model.to(self._device or pl_module.device)
 
-        optimizers = trainer.optimizers
+        optimizer = trainer.optimizers[0]
+        if self._swa_lrs is None:
+            self._swa_lrs = [param_group["lr"] for param_group in optimizer.param_groups]
+        if isinstance(self._swa_lrs, float):
+            self._swa_lrs = [self._swa_lrs] * len(optimizer.param_groups)
 
-        for param_group in optimizers[0].param_groups:
-            if self._swa_lrs is None:
-                initial_lr = param_group["lr"]
-
-            elif isinstance(self._swa_lrs, float):
-                initial_lr = self._swa_lrs
-
-            else:
-                initial_lr = self._swa_lrs[0]
-
-            param_group["initial_lr"] = initial_lr
-
-        self._swa_lrs = initial_lr
+        for lr, group in zip(self._swa_lrs, optimizer.param_groups):
+            group["initial_lr"] = lr
 
         self._swa_scheduler = SWALR(
-            optimizers[0],
-            swa_lr=initial_lr,
+            optimizer,
+            swa_lr=self._swa_lrs,
             anneal_epochs=self._annealing_epochs,
             anneal_strategy=self._annealing_strategy,
             last_epoch=trainer.max_epochs if self._annealing_strategy == "cos" else -1,
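
Stripped of the callback plumbing, the new code path normalizes `swa_lrs` into one value per parameter group and hands the list to `torch.optim.swa_utils.SWALR`, which accepts either a single learning rate or one per group. A rough standalone sketch of the same idea in plain PyTorch follows (the model and optimizer here are illustrative, not from this commit):

import torch
from torch.optim.swa_utils import SWALR

model = torch.nn.Sequential(torch.nn.Linear(32, 32), torch.nn.Linear(32, 2))
optimizer = torch.optim.Adam(
    [
        {"params": model[0].parameters(), "lr": 0.1},
        {"params": model[1].parameters(), "lr": 0.2},
    ]
)

swa_lrs = [0.123, 0.321]  # may also be None or a single float, as in the callback

# Normalize to one SWA learning rate per param group, mirroring the fixed logic above.
if swa_lrs is None:
    swa_lrs = [group["lr"] for group in optimizer.param_groups]
if isinstance(swa_lrs, float):
    swa_lrs = [swa_lrs] * len(optimizer.param_groups)

for lr, group in zip(swa_lrs, optimizer.param_groups):
    group["initial_lr"] = lr

# SWALR anneals each param group towards its own swa_lr.
swa_scheduler = SWALR(optimizer, swa_lr=swa_lrs, anneal_epochs=10, anneal_strategy="cos")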

tests/callbacks/test_stochastic_weight_avg.py

Lines changed: 38 additions & 2 deletions
@@ -210,8 +210,8 @@ def configure_optimizers(self):
     )
     trainer.fit(model)
     if use_callbacks or stochastic_weight_avg:
-        assert len([cb for cb in trainer.callbacks if isinstance(cb, StochasticWeightAveraging)]) == 1
-        assert trainer.callbacks[0]._swa_lrs == (1e-3 if use_callbacks else 0.1)
+        assert sum(1 for cb in trainer.callbacks if isinstance(cb, StochasticWeightAveraging)) == 1
+        assert trainer.callbacks[0]._swa_lrs == [1e-3 if use_callbacks else 0.1]
     else:
         assert all(not isinstance(cb, StochasticWeightAveraging) for cb in trainer.callbacks)
 
@@ -237,3 +237,39 @@ def on_before_accelerator_backend_setup(self, trainer: "Trainer", pl_module: "Li
     trainer = Trainer(default_root_dir=tmpdir, callbacks=swa, fast_dev_run=True)
     trainer.fit(model, train_dataloader=DataLoader(RandomDataset(32, 2)))
     assert swa.on_before_accelerator_backend_setup_called
+
+
+def test_swa_multiple_lrs(tmpdir):
+    swa_lrs = [0.123, 0.321]
+
+    class TestModel(BoringModel):
+        def __init__(self):
+            super(BoringModel, self).__init__()
+            self.layer1 = torch.nn.Linear(32, 32)
+            self.layer2 = torch.nn.Linear(32, 2)
+
+        def forward(self, x):
+            x = self.layer1(x)
+            x = self.layer2(x)
+            return x
+
+        def configure_optimizers(self):
+            params = [{"params": self.layer1.parameters(), "lr": 0.1}, {"params": self.layer2.parameters(), "lr": 0.2}]
+            return torch.optim.Adam(params)
+
+        def on_train_epoch_start(self):
+            optimizer = trainer.optimizers[0]
+            assert [pg["lr"] for pg in optimizer.param_groups] == [0.1, 0.2]
+            assert [pg["initial_lr"] for pg in optimizer.param_groups] == swa_lrs
+            assert [pg["swa_lr"] for pg in optimizer.param_groups] == swa_lrs
+            self.on_train_epoch_start_called = True
+
+    model = TestModel()
+    swa_callback = StochasticWeightAveraging(swa_lrs=swa_lrs)
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        callbacks=swa_callback,
+        fast_dev_run=1,
+    )
+    trainer.fit(model)
+    assert model.on_train_epoch_start_called
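
For contrast, the pre-fix callback (the removed lines in the stochastic_weight_avg.py diff above) only ever read the first entry of the list, so in this test both param groups would have received 0.123 and the per-group assertions in on_train_epoch_start would fail. A small sketch of that old behaviour (the two-group optimizer is illustrative):

import torch

model = torch.nn.Sequential(torch.nn.Linear(32, 32), torch.nn.Linear(32, 2))
optimizer = torch.optim.Adam(
    [
        {"params": model[0].parameters(), "lr": 0.1},
        {"params": model[1].parameters(), "lr": 0.2},
    ]
)

swa_lrs = [0.123, 0.321]
for param_group in optimizer.param_groups:
    initial_lr = swa_lrs[0]  # 0.123 for every group; the second entry is never read
    param_group["initial_lr"] = initial_lr

# Every group gets the first value, so the per-group assertions above would fail.
assert [pg["initial_lr"] for pg in optimizer.param_groups] == [0.123, 0.123]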
