Commit a69b940

Merge pull request #9606 from PyTorchLightning/bugfix/move_grad_tracking
[Bugfix] Fix location of `unscale` in mixed precision plugin
1 parent 25bfd06 commit a69b940

File tree: 5 files changed, +68 -7 lines changed

CHANGELOG.md

Lines changed: 7 additions & 0 deletions
@@ -5,13 +5,20 @@ All notable changes to this project will be documented in this file.
 The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 
+## [unreleased] - 2021-??-??
+
+- Moved the gradient unscaling in `NativeMixedPrecisionPlugin` from `pre_optimizer_step` to `post_backward` ([#9606](https://github.com/PyTorchLightning/pytorch-lightning/pull/9606))
+- Fixed gradient unscaling being called too late, causing gradient clipping and gradient norm tracking to be applied incorrectly ([#9606](https://github.com/PyTorchLightning/pytorch-lightning/pull/9606))
+
+
 ## [1.4.8] - 2021-09-22
 
 - Fixed error reporting in DDP process reconciliation when processes are launched by an external agent (#9389)
 - Added PL_RECONCILE_PROCESS environment variable to enable process reconciliation regardless of cluster environment settings (#9389)
 - Fixed `add_argparse_args` raising `TypeError` when args are typed as `typing.Generic` in Python 3.6 ([#9554](https://github.com/PyTorchLightning/pytorch-lightning/pull/9554))
 - Fixed back-compatibility for saving hyperparameters from a single container and inferring its argument name by reverting [#9125](https://github.com/PyTorchLightning/pytorch-lightning/pull/9125) ([#9642](https://github.com/PyTorchLightning/pytorch-lightning/pull/9642))
 
+
 ## [1.4.7] - 2021-09-14
 
 - Fixed logging of nan parameters ([#9364](https://github.com/PyTorchLightning/pytorch-lightning/pull/9364))
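
For context, the ordering the changelog entries above describe is the standard torch.cuda.amp recipe: gradients must be unscaled before any clipping or norm tracking, and only then stepped through the scaler. A minimal, self-contained sketch in plain PyTorch (not Lightning internals; the model, data, and hyperparameters are placeholders chosen only for illustration):

import torch

device = "cuda"  # native AMP requires a CUDA device
model = torch.nn.Linear(4, 2).to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scaler = torch.cuda.amp.GradScaler()

for _ in range(4):
    x = torch.randn(8, 4, device=device)
    y = torch.randn(8, 2, device=device)
    optimizer.zero_grad()
    with torch.cuda.amp.autocast():
        loss = torch.nn.functional.mse_loss(model(x), y)
    scaler.scale(loss).backward()
    # Unscale *before* clipping, so the max_norm threshold is applied to the real
    # gradients rather than gradients multiplied by the loss scale.
    scaler.unscale_(optimizer)
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=2.0)
    scaler.step(optimizer)  # skipped internally if the unscaled grads are non-finite
    scaler.update()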

pytorch_lightning/accelerators/accelerator.py

Lines changed: 3 additions & 3 deletions
@@ -264,7 +264,7 @@ def validation_step_end(self, output: Optional[STEP_OUTPUT]) -> Optional[STEP_OUTPUT]:
         """
         return self.training_type_plugin.validation_step_end(output)
 
-    def backward(self, closure_loss: Tensor, *args: Any, **kwargs: Any) -> Tensor:
+    def backward(self, closure_loss: Tensor, optimizer: torch.optim.Optimizer, *args: Any, **kwargs: Any) -> Tensor:
         """Forwards backward-calls to the precision plugin.
 
         Args:
@@ -273,9 +273,9 @@ def backward(self, closure_loss: Tensor, *args: Any, **kwargs: Any) -> Tensor:
         self.training_type_plugin.pre_backward(closure_loss)
         closure_loss = self.precision_plugin.pre_backward(self.lightning_module, closure_loss)
 
-        self.precision_plugin.backward(self.lightning_module, closure_loss, *args, **kwargs)
+        self.precision_plugin.backward(self.lightning_module, closure_loss, optimizer, *args, **kwargs)
 
-        closure_loss = self.precision_plugin.post_backward(self.lightning_module, closure_loss)
+        closure_loss = self.precision_plugin.post_backward(self.lightning_module, closure_loss, optimizer)
         self.training_type_plugin.post_backward(closure_loss)
 
         return closure_loss
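
The hunk above threads the optimizer through `Accelerator.backward` so the precision plugin receives it in both `backward` and `post_backward`. As a hypothetical illustration of the new hook signature (not part of this commit), a custom precision plugin could now inspect the optimizer's parameter groups right after backward, for example to record the gradient norm:

import torch
from pytorch_lightning.plugins.precision.precision_plugin import PrecisionPlugin


class GradNormTrackingPrecisionPlugin(PrecisionPlugin):
    """Hypothetical plugin: records the total gradient norm right after backward."""

    def __init__(self):
        super().__init__()
        self.last_grad_norm = None

    def post_backward(self, model, closure_loss, optimizer):
        closure_loss = super().post_backward(model, closure_loss, optimizer)
        # the optimizer argument makes per-parameter-group inspection possible here
        grads = [
            p.grad.detach().flatten()
            for group in optimizer.param_groups
            for p in group["params"]
            if p.grad is not None
        ]
        if grads:
            self.last_grad_norm = torch.cat(grads).norm().item()
        return closure_loss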

pytorch_lightning/plugins/precision/native_amp.py

Lines changed: 13 additions & 1 deletion
@@ -55,7 +55,10 @@ def pre_optimizer_step(
                 " To request, please file a Github issue in PyTorch and tag @mcarilli"
             )
         result = lambda_closure()  # native amp does not support closures
-        self.scaler.unscale_(optimizer)
+        if not model.automatic_optimization:
+            # unscale in manual optimization as user does not rely on lightning
+            # to call backward, but does call LightningOptimizer.step
+            self.scaler.unscale_(optimizer)
         super().pre_optimizer_step(model, optimizer, optimizer_idx, lambda_closure, **kwargs)
         skipped_backward = result is None
         # in manual optimization, the closure does not return a value
@@ -65,6 +68,15 @@ def pre_optimizer_step(
             self.scaler.update()
         return False
 
+    def post_backward(
+        self, model: "pl.LightningModule", closure_loss: torch.Tensor, optimizer: Optimizer
+    ) -> torch.Tensor:
+        ret_val = super().post_backward(model, closure_loss, optimizer)
+        # unscale here to have it inside the closure before the grad tracking and clipping
+        if model.automatic_optimization and not model.trainer.fit_loop.should_accumulate():
+            self.scaler.unscale_(optimizer)
+        return ret_val
+
     @contextmanager
     def train_step_context(self) -> Generator[None, None, None]:
         """Enable autocast context"""

pytorch_lightning/plugins/precision/precision_plugin.py

Lines changed: 5 additions & 2 deletions
@@ -79,8 +79,10 @@ def backward(
         else:
             closure_loss.backward(*args, **kwargs)
 
-    def post_backward(self, model: "pl.LightningModule", closure_loss: Tensor) -> Tensor:
-        """Run after precision plugin executes backward
+    def post_backward(
+        self, model: "pl.LightningModule", closure_loss: Tensor, optimizer: torch.optim.Optimizer
+    ) -> Tensor:
+        """Run after precision plugin executes backward.
 
         Args:
             model: the model to be optimized
@@ -89,6 +91,7 @@ def post_backward(self, model: "pl.LightningModule", closure_loss: Tensor) -> Tensor:
         # once backward has been applied, release graph
         closure_loss = closure_loss.detach()
         model.trainer.call_hook("on_after_backward")
+
         return closure_loss
 
     def pre_optimizer_step(
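
With the base-class signature above, the optimizer is available wherever `post_backward` runs, and under native AMP the Trainer-level features that read gradients inside the closure now operate on unscaled values. A hypothetical usage sketch of those features (assumes a model and data defined elsewhere, and a CUDA device):

from pytorch_lightning import Trainer

trainer = Trainer(
    precision=16,
    amp_backend="native",
    gpus=1,
    gradient_clip_val=0.5,  # the clip threshold is now compared against unscaled gradients
    track_grad_norm=2,      # logged norms are the true (unscaled) L2 norms
)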

tests/plugins/test_amp_plugins.py

Lines changed: 40 additions & 1 deletion
@@ -18,7 +18,8 @@
 import pytest
 import torch
 
-from pytorch_lightning import Trainer
+from pytorch_lightning import seed_everything, Trainer
+from pytorch_lightning.core.optimizer import LightningOptimizer
 from pytorch_lightning.plugins import ApexMixedPrecisionPlugin, NativeMixedPrecisionPlugin
 from pytorch_lightning.plugins.precision import MixedPrecisionPlugin
 from tests.helpers import BoringModel
@@ -174,3 +175,41 @@ def test_amp_apex_ddp_spawn_fit(amp_level, tmpdir):
     assert isinstance(trainer.precision_plugin, ApexMixedPrecisionPlugin)
     model = BoringModel()
     trainer.fit(model)
+
+
+class GradientUnscaleNativeAMPPlugin(NativeMixedPrecisionPlugin):
+    _was_scaled_finite = 0
+
+    def post_backward(self, model, closure_loss, optimizer) -> torch.Tensor:
+        assert not isinstance(optimizer, LightningOptimizer)
+        norm_before = torch.nn.utils.clip_grad_norm_(model.parameters(), 2)
+        ret_val = super().post_backward(model, closure_loss, optimizer)
+        norm_after = torch.nn.utils.clip_grad_norm_(model.parameters(), 2)
+
+        # norm_after unscale should be smaller by scaling factor greater than 1
+        if not (torch.isinf(norm_before) or torch.isnan(norm_before)):
+            assert norm_after < norm_before * 10
+            # during initial phase of finding the appropriate scaling, AMP skips optimizer steps that have
+            # non-finite gradients; we count and assert that we had at least one finite gradient here
+            self._was_scaled_finite += 1
+        return ret_val
+
+
+@RunIf(min_gpus=1, amp_native=True)
+def test_correct_native_grad_unscaling(tmpdir):
+    """Test that the gradient clipping gets applied at the appropriate place when using mixed precision plugins."""
+    seed_everything(42)
+    plugin = GradientUnscaleNativeAMPPlugin()
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        fast_dev_run=4,
+        max_epochs=1,
+        precision=16,
+        amp_backend="native",
+        gpus=1,
+        plugins=plugin,
+    )
+    assert isinstance(trainer.precision_plugin, GradientUnscaleNativeAMPPlugin)
+    model = BoringModel()
+    trainer.fit(model)
+    assert plugin._was_scaled_finite
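
The new test is gated by `@RunIf(min_gpus=1, amp_native=True)`, so it only runs on a machine with at least one CUDA device and native AMP available; locally it can be selected with, for example, `pytest tests/plugins/test_amp_plugins.py -k test_correct_native_grad_unscaling`.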
