Commit 6b9e0a5

awaelchli and pre-commit-ci[bot] authored and committed
Support skipping training step when using mixed precision training (#18267)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
(cherry picked from commit 97020bf)

1 parent 6d63651 · commit 6b9e0a5
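
For orientation, this is the user-facing behavior the commit addresses: in automatic optimization, `training_step` may return `None` to skip the backward pass and the optimizer step for that batch. Below is a minimal illustrative sketch of that pattern under mixed precision; the model class and skip condition are made up here, mirroring the new test added in this commit rather than code shipped in it.

```python
# Illustrative sketch only; `SometimesSkip` and its skip condition are hypothetical,
# modeled on the test added in this commit.
from lightning.pytorch import Trainer
from lightning.pytorch.demos.boring_classes import BoringModel


class SometimesSkip(BoringModel):
    def training_step(self, batch, batch_idx):
        if batch_idx % 2:
            # Returning None skips backward; with "16-mixed" the grad-scaler
            # unscale_/step/update calls are now skipped for this batch as well.
            return None
        return super().training_step(batch, batch_idx)


if __name__ == "__main__":
    trainer = Trainer(accelerator="cuda", devices=1, precision="16-mixed", max_steps=4)
    trainer.fit(SometimesSkip())
```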

File tree: 3 files changed (+40, -3 lines changed)


src/lightning/pytorch/CHANGELOG.md (3 additions, 0 deletions)

@@ -31,6 +31,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed an issue that would prevent the user to set the multiprocessing start method after importing lightning ([#18177](https://github.com/Lightning-AI/lightning/pull/18177))


+- Fixed the gradient unscaling logic if the training step skipped backward (by returning `None`) ([#18267](https://github.com/Lightning-AI/lightning/pull/18267))
+
+
 - Ensure that the closure running inside the optimizer step has gradients enabled, even if the optimizer step has it disabled ([#18268](https://github.com/Lightning-AI/lightning/pull/18268))

src/lightning/pytorch/plugins/precision/amp.py (6 additions, 3 deletions)

@@ -75,16 +75,19 @@ def optimizer_step(  # type: ignore[override]
             raise MisconfigurationException("AMP and the LBFGS optimizer are not compatible.")
         closure_result = closure()

-        if not _optimizer_handles_unscaling(optimizer):
+        # If backward was skipped in automatic optimization (return None), unscaling is not needed
+        skip_unscaling = closure_result is None and model.automatic_optimization
+
+        if not _optimizer_handles_unscaling(optimizer) and not skip_unscaling:
             # Unscaling needs to be performed here in case we are going to apply gradient clipping.
             # Optimizers that perform unscaling in their `.step()` method are not supported (e.g., fused Adam).
             # Note: `unscale` happens after the closure is executed, but before the `on_before_optimizer_step` hook.
             self.scaler.unscale_(optimizer)

         self._after_closure(model, optimizer)
-        skipped_backward = closure_result is None
+
         # in manual optimization, the closure does not return a value
-        if not model.automatic_optimization or not skipped_backward:
+        if not skip_unscaling:
             # note: the scaler will skip the `optimizer.step` if nonfinite gradients are found
             step_output = self.scaler.step(optimizer, **kwargs)
             self.scaler.update()
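
For background on why unscaling must be skipped: the plugin drives the standard `torch.cuda.amp.GradScaler` protocol around each optimizer step. A plain-PyTorch sketch of that protocol follows (independent of Lightning; `model`, `optimizer`, `loss_fn`, and `batch` are placeholders, not names from this commit). When the closure returns `None`, `scaler.scale(loss).backward()` never ran, so there are no freshly scaled gradients to unscale, and the change above guards the `unscale_`, `step`, and `update` calls accordingly.

```python
# Plain-PyTorch sketch of the GradScaler sequence the plugin wraps per step.
# `model`, `optimizer`, `loss_fn`, `batch` are placeholders for illustration.
import torch

scaler = torch.cuda.amp.GradScaler()


def amp_step(model, optimizer, loss_fn, batch):
    optimizer.zero_grad()
    with torch.autocast(device_type="cuda", dtype=torch.float16):
        loss = loss_fn(model(batch))
    scaler.scale(loss).backward()  # gradients are computed in scaled form
    scaler.unscale_(optimizer)     # unscale before gradient clipping
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=0.5)
    scaler.step(optimizer)         # skips optimizer.step() if grads are non-finite
    scaler.update()                # adjusts the scale factor for the next step
```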

tests/tests_pytorch/plugins/precision/test_amp_integration.py (31 additions, 0 deletions)

@@ -11,11 +11,14 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from unittest.mock import Mock
+
 import torch

 from lightning.fabric import seed_everything
 from lightning.pytorch import Trainer
 from lightning.pytorch.demos.boring_classes import BoringModel
+from lightning.pytorch.plugins.precision import MixedPrecisionPlugin
 from tests_pytorch.helpers.runif import RunIf


@@ -54,3 +57,31 @@ def run(fused=False):
     # Both the regular and the fused version of Adam produce the same losses and model weights
     for p, q in zip(params, params_fused):
         torch.testing.assert_close(p, q)
+
+
+@RunIf(min_cuda_gpus=1)
+def test_skip_training_step_with_grad_scaler():
+    """Test that the grad scaler gets skipped when skipping a training step."""
+
+    class TestModel(BoringModel):
+        def training_step(self, batch, batch_idx):
+            if batch_idx % 2:
+                return None  # skipping the backward should skip the grad scaler too
+            return super().training_step(batch, batch_idx)
+
+    trainer = Trainer(
+        accelerator="cuda",
+        devices=1,
+        precision="16-mixed",
+        barebones=True,
+        max_steps=5,
+        gradient_clip_val=0.5,
+    )
+    assert isinstance(trainer.precision_plugin, MixedPrecisionPlugin)
+    assert trainer.precision_plugin.scaler is not None
+    trainer.precision_plugin.scaler = Mock(wraps=trainer.precision_plugin.scaler)
+    model = TestModel()
+    trainer.fit(model)
+    assert trainer.precision_plugin.scaler.unscale_.call_count == 3
+    assert trainer.precision_plugin.scaler.step.call_count == 3
+    assert trainer.precision_plugin.scaler.update.call_count == 3
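
The expected call counts follow directly from the test model: with `max_steps=5`, batches with `batch_idx` 1 and 3 return `None`, so only batches 0, 2, and 4 run backward, and the wrapped scaler's `unscale_`, `step`, and `update` should each be invoked exactly 3 times.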
