
Commit 6d63651

awaelchli authored and lexierule committed
Refresh the internal LightningOptimizer state for inspection (#18280)
(cherry picked from commit 4da2d87)
1 parent 92d689e commit 6d63651

4 files changed (+72, -14 lines)

src/lightning/pytorch/CHANGELOG.md

Lines changed: 7 additions & 0 deletions
```diff
@@ -7,6 +7,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 ## [UnReleased] - 2023-08-DD
 
+- Added `LightningOptimizer.refresh()` to update the `__dict__` in case the optimizer it wraps has changed its internal state ([#18280](https://github.com/Lightning-AI/lightning/pull/18280))
+
+
 ### Changed
 
 - Disabled the auto-detection of the Kubeflow environment ([#18137](https://github.com/Lightning-AI/lightning/pull/18137))
@@ -31,6 +34,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Ensure that the closure running inside the optimizer step has gradients enabled, even if the optimizer step has it disabled ([#18268](https://github.com/Lightning-AI/lightning/pull/18268))
 
 
+- Fixed an issue that could cause the `LightningOptimizer` wrapper returned by `LightningModule.optimizers()` to have a different internal state than the optimizer it wraps ([#18280](https://github.com/Lightning-AI/lightning/pull/18280))
+
+
+
 ## [2.0.5] - 2023-07-07
 
 ### Fixed
```
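
The new `refresh()` call is easiest to see in isolation. The snippet below is a minimal sketch of the behavior the changelog entry describes, mirroring the `test_state_mutation` test added further down in this commit; the model and optimizer setup are illustrative only.

```python
from copy import deepcopy

import torch
from lightning.pytorch.core.optimizer import LightningOptimizer

model = torch.nn.Linear(2, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
wrapper = LightningOptimizer(optimizer)

# `load_state_dict` rebinds `param_groups` on the wrapped optimizer, so the copy held
# in the wrapper's `__dict__` goes stale until `refresh()` re-copies it.
state = deepcopy(optimizer.state_dict())
state["param_groups"][0]["lr"] = 0.5
optimizer.load_state_dict(state)

assert optimizer.param_groups[0]["lr"] == 0.5
assert wrapper.param_groups[0]["lr"] == 0.1  # stale view through the wrapper
wrapper.refresh()
assert wrapper.param_groups[0]["lr"] == 0.5  # back in sync
```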

src/lightning/pytorch/core/module.py

Lines changed: 2 additions & 0 deletions
```diff
@@ -159,6 +159,8 @@ def optimizers(self, use_pl_optimizer: bool = True) -> MODULE_OPTIMIZERS:
             opts: MODULE_OPTIMIZERS = self._fabric_optimizers
         elif use_pl_optimizer:
             opts = self.trainer.strategy._lightning_optimizers
+            for opt in opts:
+                opt.refresh()
         else:
             opts = self.trainer.optimizers
 
```
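
In user code, the effect of this two-line change is that the wrappers handed out by `LightningModule.optimizers()` always reflect the current state of the underlying optimizers. Below is a hedged sketch built on Lightning's `BoringModel` demo class; the hook choice and the manual learning-rate change are assumptions for the example, chosen to mirror the test updated at the bottom of this commit.

```python
import lightning.pytorch as pl
from lightning.pytorch.demos.boring_classes import BoringModel


class InspectingModel(BoringModel):
    def on_train_start(self):
        # Mutate the raw optimizer held by the Trainer, bypassing the LightningOptimizer wrapper
        self.trainer.optimizers[0].param_groups[0]["lr"] = 2.0

    def on_train_batch_start(self, batch, batch_idx):
        # `optimizers()` now refreshes the wrapper before returning it, so the inspected value is current
        assert self.optimizers().param_groups[0]["lr"] == 2.0


if __name__ == "__main__":
    trainer = pl.Trainer(fast_dev_run=1, logger=False, enable_checkpointing=False)
    trainer.fit(InspectingModel())
```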

src/lightning/pytorch/core/optimizer.py

Lines changed: 28 additions & 14 deletions
```diff
@@ -34,12 +34,15 @@ def do_nothing_closure() -> None:
 
 class LightningOptimizer:
     """This class is used to wrap the user optimizers and handle properly the backward and optimizer_step logic across
-    accelerators, AMP, accumulate_grad_batches."""
+    accelerators, AMP, accumulate_grad_batches.
+
+    Note: The purpose of this wrapper is only to define new methods and redirect the `.step()` call. The internal
+    state ``__dict__`` is not kept in sync with the internal state of the original optimizer, but the Trainer never
+    relies on the internal state of the wrapper.
+
+    """
 
     def __init__(self, optimizer: Optimizer):
-        # copy most of the `Optimizer` methods into this instance. `__del__` is skipped in case the optimizer has
-        # implemented custom logic which we would not want to call on destruction of the `LightningOptimizer`
-        self.__dict__ = {k: v for k, v in optimizer.__dict__.items() if k not in ("step", "__del__")}
         self.__class__ = type("Lightning" + optimizer.__class__.__name__, (self.__class__, optimizer.__class__), {})
 
         self._optimizer = optimizer
@@ -48,20 +51,12 @@ def __init__(self, optimizer: Optimizer):
         self._on_before_step = do_nothing_closure
         self._on_after_step = do_nothing_closure
 
+        self.refresh()
+
     @property
     def optimizer(self) -> Optimizer:
         return self._optimizer
 
-    @classmethod
-    def _to_lightning_optimizer(
-        cls, optimizer: Union[Optimizer, "LightningOptimizer"], strategy: "pl.strategies.Strategy"
-    ) -> "LightningOptimizer":
-        # the user could return a `LightningOptimizer` from `configure_optimizers`, see test:
-        # tests/core/test_lightning_optimizer.py::test_lightning_optimizer[False]
-        lightning_optimizer = optimizer if isinstance(optimizer, LightningOptimizer) else cls(optimizer)
-        lightning_optimizer._strategy = proxy(strategy)
-        return lightning_optimizer
-
     @contextmanager
     def toggle_model(self, sync_grad: bool = True) -> Generator[None, None, None]:
         """This function is just a helper for advanced users.
@@ -85,6 +80,15 @@ def toggle_model(self, sync_grad: bool = True) -> Generator[None, None, None]:
         yield
         lightning_module.untoggle_optimizer(self)
 
+    def refresh(self) -> None:
+        """Refreshes the ``__dict__`` so that it matches the internal states in the wrapped optimizer.
+
+        This is only needed to present the user with an updated view in case they inspect the state of this wrapper.
+        """
+        # copy most of the `Optimizer` methods into this instance. `__del__` is skipped in case the optimizer has
+        # implemented custom logic which we would not want to call on destruction of the `LightningOptimizer`
+        self.__dict__.update({k: v for k, v in self.optimizer.__dict__.items() if k not in ("step", "__del__")})
+
     def step(self, closure: Optional[Callable[[], Any]] = None, **kwargs: Any) -> Any:
         """Performs a single optimization step (parameter update).
 
@@ -160,6 +164,16 @@ def closure_dis():
 
         return step_output
 
+    @classmethod
+    def _to_lightning_optimizer(
+        cls, optimizer: Union[Optimizer, "LightningOptimizer"], strategy: "pl.strategies.Strategy"
+    ) -> "LightningOptimizer":
+        # the user could return a `LightningOptimizer` from `configure_optimizers`, see test:
+        # tests/core/test_lightning_optimizer.py::test_lightning_optimizer[False]
+        lightning_optimizer = optimizer if isinstance(optimizer, LightningOptimizer) else cls(optimizer)
+        lightning_optimizer._strategy = proxy(strategy)
+        return lightning_optimizer
+
 
 def _init_optimizers_and_lr_schedulers(
     model: "pl.LightningModule",
```
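
For context on why the mirrored `__dict__` exists at all: `LightningOptimizer` swaps its class for a dynamically created subclass of the wrapped optimizer's class so that `isinstance` checks keep passing, but it does not run the optimizer's own `__init__`, so attributes such as `param_groups`, `state`, and `defaults` have to be copied over by hand. That copy is what the initial `refresh()` call in the constructor now performs. A small hedged sketch of that mechanism, with illustrative variable names:

```python
import torch
from lightning.pytorch.core.optimizer import LightningOptimizer

model = torch.nn.Linear(2, 2)
optimizer = torch.optim.Adam(model.parameters(), lr=0.1)
wrapper = LightningOptimizer(optimizer)

# The dynamically created subclass makes the wrapper look like an Adam optimizer ...
assert isinstance(wrapper, torch.optim.Adam)
assert type(wrapper).__name__ == "LightningAdam"

# ... while its instance attributes come from the constructor's initial `refresh()` call,
# which copies references to the wrapped optimizer's `__dict__` entries.
assert wrapper.param_groups is optimizer.param_groups
assert wrapper.defaults == optimizer.defaults
```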

tests/tests_pytorch/core/test_lightning_optimizer.py

Lines changed: 35 additions & 0 deletions
```diff
@@ -11,6 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from copy import deepcopy
 from unittest.mock import DEFAULT, Mock, patch
 
 import pytest
@@ -160,6 +161,32 @@ def test_state():
     assert optimizer.state == lightning_optimizer.state
 
 
+def test_state_mutation():
+    model = torch.nn.Linear(3, 4)
+    optimizer0 = torch.optim.Adam(model.parameters(), lr=0.1)
+    lightning_optimizer0 = LightningOptimizer(optimizer0)
+
+    optimizer0.param_groups[0]["lr"] = 1.0
+    assert lightning_optimizer0.param_groups[0]["lr"] == 1.0
+
+    # Load state into the unwrapped optimizer
+    state_dict0 = deepcopy(optimizer0.state_dict())
+    optimizer1 = torch.optim.Adam(model.parameters(), lr=100)
+    lightning_optimizer1 = LightningOptimizer(optimizer1)
+    optimizer1.load_state_dict(state_dict0)
+
+    # LightningOptimizer needs to be refreshed to see the new state
+    assert lightning_optimizer1.param_groups[0]["lr"] != 1.0
+    lightning_optimizer1.refresh()
+    assert lightning_optimizer1.param_groups[0]["lr"] == 1.0
+
+    # Load state into wrapped optimizer
+    optimizer2 = torch.optim.Adam(model.parameters(), lr=100)
+    lightning_optimizer2 = LightningOptimizer(optimizer2)
+    lightning_optimizer2.load_state_dict(state_dict0)
+    assert lightning_optimizer2.param_groups[0]["lr"] == 1.0
+
+
 def test_lightning_optimizer_automatic_optimization_optimizer_zero_grad(tmpdir):
     """Test overriding zero_grad works in automatic_optimization."""
 
@@ -296,7 +323,15 @@ def test_lightning_optimizer_keeps_hooks():
 
 def test_params_groups_and_state_are_accessible(tmpdir):
     class TestModel(BoringModel):
+        def on_train_start(self):
+            # Update the learning rate manually on the unwrapped optimizer
+            assert not isinstance(self.trainer.optimizers[0], LightningOptimizer)
+            self.trainer.optimizers[0].param_groups[0]["lr"] = 2.0
+
         def training_step(self, batch, batch_idx):
+            opt = self.optimizers()
+            assert opt.param_groups[0]["lr"] == 2.0
+
             loss = self.step(batch)
             self.__loss = loss
             return loss
```
