Commit c76cc23

tchaton and rohitgr7 authored
[bugfix] Resolve bug with multiple optimizers and toggle. (#5574)
* fix toggle_optimizer
* update doc
* resolve bug
* update
* Update pytorch_lightning/core/lightning.py

Co-authored-by: Rohit Gupta <[email protected]>

* update on comments
* update on comments
* update

Co-authored-by: Rohit Gupta <[email protected]>
1 parent e87424a · commit c76cc23

4 files changed: +118 -9 lines changed

CHANGELOG.md

Lines changed: 3 additions & 1 deletion

@@ -26,6 +26,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

 ### Fixed

+- Fixed `toggle_optimizer` to reset `requires_grad` state ([#5574](https://github.com/PyTorchLightning/pytorch-lightning/pull/5574))
+
+
 - Fixed an error when logging a progress bar metric with a reserved name ([#5620](https://github.com/PyTorchLightning/pytorch-lightning/pull/5620))


@@ -63,7 +66,6 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Check environ before selecting a seed to prevent warning message ([#4743](https://github.com/PyTorchLightning/pytorch-lightning/pull/4743))
 - Fixed signature mismatch in `model_to_device` of `DDPCPUHPCAccelerator` ([#5505](https://github.com/PyTorchLightning/pytorch-lightning/pull/5505))

-
 ## [1.1.3] - 2021-01-05

 ### Added

pytorch_lightning/core/lightning.py

Lines changed: 38 additions & 7 deletions

@@ -1170,16 +1170,47 @@ def toggle_optimizer(self, optimizer: Optimizer, optimizer_idx: int):

         Override for your own behavior

+        It works with ``untoggle_optimizer`` to make sure param_requires_grad_state is properly reset.
+
         Args:
-            optimizer:
-            optimizer_idx:
+            optimizer: Current optimizer used in training_loop
+            optimizer_idx: Current optimizer idx in training_loop
         """
-        for param in self.parameters():
-            param.requires_grad = False
+        param_requires_grad_state = {}
+        # make sure current optimizer is latest to be iterated over.
+        optimizers = [opt for opt in self.optimizers(use_pl_optimizer=False) if opt != optimizer] + [optimizer]
+        num_optimizers = len(optimizers) - 1
+        for opt_idx, opt in enumerate(optimizers):
+            for group in opt.param_groups:
+                for param in group['params']:
+                    if num_optimizers == opt_idx:
+                        # If a param appears in 2 optimizers, revert `requires_grad` to before toggle.
+                        if param in param_requires_grad_state:
+                            param.requires_grad = param_requires_grad_state[param]
+                    else:
+                        # save requires_grad for later restoration
+                        param_requires_grad_state[param] = param.requires_grad
+                        param.requires_grad = False
+
+        self._param_requires_grad_state = param_requires_grad_state
+
+    def untoggle_optimizer(self, optimizer_idx: int):
+        """
+        .. note:: Only called when using multiple optimizers

-        for group in optimizer.param_groups:
-            for param in group['params']:
-                param.requires_grad = True
+        Override for your own behavior
+
+        Args:
+            optimizer_idx: Current optimizer idx in training_loop
+        """
+        for opt_idx, opt in enumerate(self.optimizers(use_pl_optimizer=False)):
+            if optimizer_idx != opt_idx:
+                for group in opt.param_groups:
+                    for param in group['params']:
+                        if param in self._param_requires_grad_state:
+                            param.requires_grad = self._param_requires_grad_state[param]
+        # save memory
+        del self._param_requires_grad_state

     def optimizer_step(
         self,
pytorch_lightning/trainer/training_loop.py

Lines changed: 4 additions & 0 deletions

@@ -798,6 +798,10 @@ def training_step_and_backward(self, split_batch, batch_idx, opt_idx, optimizer,
                 if self.trainer.terminate_on_nan:
                     self.trainer.detect_nan_tensors(result.loss)

+                if len(self.trainer.optimizers) > 1:
+                    # revert back to previous state
+                    self.trainer.get_model().untoggle_optimizer(opt_idx)
+
         return result

     def backward(self, result, optimizer, opt_idx, *args, **kwargs):
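For context, the per-optimizer flow around this new call looks roughly as follows. This is a paraphrase, not the literal Trainer code; the function name and its arguments are placeholders:

    # Rough shape of the per-optimizer flow after this change (paraphrase only;
    # names here are placeholders, not the Trainer's actual API).
    def run_training_step_and_backward(trainer, model, split_batch, batch_idx, opt_idx, optimizer):
        if len(trainer.optimizers) > 1:
            # freeze params owned by the other optimizers, remembering their flags
            model.toggle_optimizer(optimizer, opt_idx)

        result = model.training_step(split_batch, batch_idx, opt_idx)
        # ... backward pass, NaN checks, hooks ...

        if len(trainer.optimizers) > 1:
            # new in this commit: restore the remembered requires_grad flags
            model.untoggle_optimizer(opt_idx)

        return result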

tests/core/test_lightning_module.py

Lines changed: 73 additions & 1 deletion

@@ -11,13 +11,14 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from argparse import ArgumentParser
 import pickle
+from argparse import ArgumentParser
 from typing import Optional
 from unittest.mock import MagicMock, patch

 import pytest
 import torch
+from torch import nn
 from torch.optim import Adam, SGD
 from torch.utils.data import DataLoader, random_split

@@ -139,3 +140,74 @@ def optimizer_step(self, current_epoch, batch_nb, optimizer, optimizer_idx, clos
     )

     trainer.fit(model)
+
+
+def test_toggle_untoggle(tmpdir):
+
+    class TestModel(BoringModel):
+
+        def training_step(self, batch, batch_idx, optimizer_idx=None):
+            return super().training_step(batch, batch_idx)
+
+        def __init__(self):
+            super().__init__()
+            self.layer_1 = nn.Sequential(
+                nn.Linear(32, 32),
+                nn.ReLU(),
+                nn.Linear(32, 32),
+                nn.ReLU(),
+                nn.Linear(32, 32),
+            )
+
+            self.layer_2 = nn.Sequential(
+                nn.ReLU(),
+                nn.Linear(32, 32),
+                nn.ReLU(),
+                nn.Linear(32, 32),
+                nn.ReLU(),
+                nn.Linear(32, 2)
+            )
+
+            # set some weights to False to check untoggle works as expected.
+            self.layer_1[2].weight.requires_grad = False
+            self.layer_1[4].weight.requires_grad = False
+
+            self.layer_2[1].weight.requires_grad = False
+            self.layer_2[3].weight.requires_grad = False
+
+        def configure_optimizers(self):
+            optimizer = SGD(self.layer_1.parameters(), lr=0.1)
+            optimizer_2 = Adam(self.layer_2.parameters(), lr=0.1)
+            return [optimizer, optimizer_2]
+
+        def optimizer_step(self, current_epoch, batch_nb, optimizer, optimizer_idx, closure, on_tpu=False, using_native_amp=False, using_lbfgs=False):
+            if optimizer_idx == 0:
+                assert self.layer_1[0].weight.requires_grad is True
+                assert self.layer_1[2].weight.requires_grad is False
+                assert self.layer_1[4].weight.requires_grad is False
+
+                assert self.layer_2[1].weight.requires_grad is False
+                assert self.layer_2[3].weight.requires_grad is False
+                assert self.layer_2[5].weight.requires_grad is False
+
+            if optimizer_idx == 1:
+                assert self.layer_1[0].weight.requires_grad is False
+                assert self.layer_1[2].weight.requires_grad is False
+                assert self.layer_1[4].weight.requires_grad is False
+
+                assert self.layer_2[1].weight.requires_grad is False
+                assert self.layer_2[3].weight.requires_grad is False
+                assert self.layer_2[5].weight.requires_grad is True
+            optimizer.step(closure=closure)
+
+    model = TestModel()
+    model.training_epoch_end = None
+
+    trainer = Trainer(
+        max_epochs=1,
+        default_root_dir=tmpdir,
+        limit_train_batches=8,
+        accumulate_grad_batches=1,
+    )
+
+    trainer.fit(model)
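To read the assertions above at a glance, the expected `requires_grad` flags per optimizer step can be summarized like this (a reading aid only, not part of the test):

    # Expected requires_grad flags inside optimizer_step, keyed by optimizer_idx.
    # layer_1[2], layer_1[4], layer_2[1], layer_2[3] were frozen by hand in
    # __init__ and must stay False in both steps; weights owned by the *other*
    # optimizer are toggled to False only for the duration of that step.
    expected = {
        0: {  # SGD step over layer_1
            "layer_1[0].weight": True,
            "layer_1[2].weight": False, "layer_1[4].weight": False,  # frozen by hand
            "layer_2[1].weight": False, "layer_2[3].weight": False,  # frozen by hand
            "layer_2[5].weight": False,                              # toggled off
        },
        1: {  # Adam step over layer_2
            "layer_1[0].weight": False,                              # toggled off
            "layer_1[2].weight": False, "layer_1[4].weight": False,  # frozen by hand
            "layer_2[1].weight": False, "layer_2[3].weight": False,  # frozen by hand
            "layer_2[5].weight": True,
        },
    }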
