Changes shown are from 13 of the 27 commits.

Commits:
9d92a4a  Fix LR scheduler behaviour with AMP (milesial, Jan 3, 2023)
2c2b138  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Jan 3, 2023)
1b3365e  Fix LR schedulers when optimizers with frequencies (milesial, Jan 9, 2023)
a237283  Move implementation to scales comparison (milesial, Jan 4, 2023)
183a6a6  Catch warnings (carmocca, Jan 10, 2023)
0406a55  Fix implementation (milesial, Jan 10, 2023)
c44a892  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Jan 10, 2023)
77043b0  Merge branch 'master' into master (Borda, Jan 12, 2023)
78fa3af  Merge branch 'master' into master (Borda, Jan 13, 2023)
74fab13  Merge branch 'master' into master (Borda, Jan 14, 2023)
e5a91cc  Merge branch 'master' into milesial/master (Borda, Feb 3, 2023)
11e4f72  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Feb 3, 2023)
406d1bd  Merge branch 'master' into milesial/master (Borda, Mar 3, 2023)
6b0f59b  Merge branch 'master' into master (Borda, Mar 24, 2023)
33454f4  Merge branch 'master' into master (Borda, Apr 14, 2023)
f5353fc  Merge branch 'master' into master (Borda, Apr 18, 2023)
9d7b2b8  Merge branch 'master' into master (Borda, Apr 26, 2023)
b997075  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Apr 26, 2023)
930ba4c  Merge branch 'master' into master (Borda, Apr 26, 2023)
13f5fb4  Merge branch 'master' into master (Borda, May 11, 2023)
374e856  Merge branch 'master' into master (Borda, May 23, 2023)
b900ded  Merge branch 'master' into master (Borda, Nov 18, 2023)
f10b897  Merge branch 'master' into milesial/master (Borda, Feb 16, 2024)
965fc03  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Feb 16, 2024)
418b8d4  Merge branch 'master' into master (Borda, Aug 19, 2025)
992de00  Merge branch 'master' into master (Borda, Sep 2, 2025)
5bff40a  Merge branch 'master' into master (Borda, Sep 4, 2025)
Files changed:
src/lightning/fabric/plugins/precision/amp.py (2 additions, 0 deletions)

@@ -79,9 +79,11 @@ def optimizer_step(
             return super().optimizer_step(optimizer, **kwargs)
         if isinstance(optimizer, LBFGS):
             raise TypeError("AMP and the LBFGS optimizer are not compatible.")
+        previous_scale = self.scaler.get_scale()
         # note: the scaler will skip the `optimizer.step` if nonfinite gradients are found
         step_output = self.scaler.step(optimizer, **kwargs)
         self.scaler.update()
+        optimizer._skip_next_scheduler_step = self.scaler.get_scale() < previous_scale
         return step_output

     def state_dict(self) -> Dict[str, Any]:
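The key observation behind this change: `GradScaler.update()` lowers the scale only when the preceding `scaler.step(optimizer)` was skipped because non-finite gradients were found, so comparing the scale before and after the update reveals whether the optimizer actually stepped. A minimal plain-PyTorch sketch of the same pattern, assuming a CUDA device and a throwaway linear model (only the scale comparison and the `_skip_next_scheduler_step` flag mirror the diff; everything else is placeholder):

    import torch

    device = "cuda"  # GradScaler only does real work on CUDA
    model = torch.nn.Linear(10, 1).to(device)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    scaler = torch.cuda.amp.GradScaler()

    inputs = torch.randn(8, 10, device=device)
    targets = torch.randn(8, 1, device=device)

    optimizer.zero_grad()
    with torch.autocast(device_type="cuda", dtype=torch.float16):
        loss = torch.nn.functional.mse_loss(model(inputs), targets)
    scaler.scale(loss).backward()

    previous_scale = scaler.get_scale()
    scaler.step(optimizer)  # silently skipped if inf/NaN gradients are found
    scaler.update()         # the scale shrinks only when the step was skipped
    # same flag name as in the diff above; a plain attribute set on the optimizer instance
    optimizer._skip_next_scheduler_step = scaler.get_scale() < previous_scale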
src/lightning/pytorch/loops/training_epoch_loop.py (3 additions, 0 deletions)

@@ -344,6 +344,9 @@ def _update_learning_rates(self, interval: str, update_plateau_schedulers: bool)
                         )
                         continue

+                if getattr(self.trainer.optimizers[config.opt_idx], "_skip_next_scheduler_step", False):
+                    continue
+
                 self.scheduler_progress.increment_ready()

                 # update LR
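This is the consumer side of the flag: before advancing a scheduler, the loop checks whether the optimizer it belongs to had its last step skipped, and if so leaves the schedule untouched for that round. A simplified sketch of that logic, where `step_schedulers` and its argument are illustrative names rather than Lightning API:

    from typing import Any, Iterable, Tuple

    import torch


    def step_schedulers(schedulers_and_optimizers: Iterable[Tuple[Any, torch.optim.Optimizer]]) -> None:
        """Advance each LR scheduler unless its optimizer had its last step skipped by AMP."""
        for scheduler, optimizer in schedulers_and_optimizers:
            # the precision plugin sets this flag when the GradScaler skipped `optimizer.step`
            if getattr(optimizer, "_skip_next_scheduler_step", False):
                continue
            scheduler.step()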
src/lightning/pytorch/plugins/precision/amp.py (2 additions, 0 deletions)

@@ -80,8 +80,10 @@ def optimizer_step(  # type: ignore[override]
         # in manual optimization, the closure does not return a value
         if not model.automatic_optimization or not skipped_backward:
             # note: the scaler will skip the `optimizer.step` if nonfinite gradients are found
+            previous_scale = self.scaler.get_scale()
             step_output = self.scaler.step(optimizer, **kwargs)
             self.scaler.update()
+            optimizer._skip_next_scheduler_step = self.scaler.get_scale() < previous_scale
             return step_output
         return closure_result
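The trainer-side plugin applies the same scale comparison as the Fabric plugin above. To see the end-to-end effect, one can force a skipped step by injecting a non-finite gradient: the scaler skips `optimizer.step()`, lowers its scale, the flag flips to True, and the learning rate is left unchanged. A hedged sketch under the same CUDA assumption as before; the model, scheduler, and numbers are placeholders:

    import math

    import torch

    model = torch.nn.Linear(10, 1).cuda()
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.5)
    scaler = torch.cuda.amp.GradScaler()

    loss = model(torch.randn(8, 10, device="cuda")).sum()
    scaler.scale(loss).backward()
    model.weight.grad[0, 0] = float("nan")  # force a non-finite gradient

    previous_scale = scaler.get_scale()
    scaler.step(optimizer)  # skipped: the unscaled gradients contain NaN
    scaler.update()         # the scale is reduced as a consequence
    optimizer._skip_next_scheduler_step = scaler.get_scale() < previous_scale

    assert optimizer._skip_next_scheduler_step
    if not optimizer._skip_next_scheduler_step:
        scheduler.step()
    assert math.isclose(optimizer.param_groups[0]["lr"], 0.1)  # LR did not advance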
tests/tests_fabric/plugins/precision/test_amp.py (1 addition, 0 deletions)

@@ -69,6 +69,7 @@ def test_amp_precision_backward():
 def test_amp_precision_optimizer_step_with_scaler():
     precision = MixedPrecision(precision="16-mixed", device="cuda")
     precision.scaler = Mock()
+    precision.scaler.get_scale = Mock(return_value=1.0)
     optimizer = Mock()

     precision.optimizer_step(optimizer, keyword="arg")
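The mocked `get_scale` is needed because the plugin now compares two scale readings; with a bare `Mock`, `get_scale()` would return `Mock` objects that do not support `<` and the comparison would raise `TypeError`. Returning a constant `1.0` makes the comparison evaluate to `False`, i.e. a non-skipped step. A hypothetical companion test along the same lines (test name, import path, and assertions are assumptions, not part of the diff):

    from unittest.mock import Mock

    from lightning.fabric.plugins.precision.amp import MixedPrecision  # assumed import path


    def test_amp_precision_sets_scheduler_skip_flag():  # hypothetical test
        precision = MixedPrecision(precision="16-mixed", device="cuda")
        precision.scaler = Mock()
        # an unchanged scale before and after update() means the step was not skipped
        precision.scaler.get_scale = Mock(return_value=1.0)
        optimizer = Mock()

        precision.optimizer_step(optimizer, keyword="arg")

        precision.scaler.step.assert_called_once_with(optimizer, keyword="arg")
        precision.scaler.update.assert_called_once()
        assert optimizer._skip_next_scheduler_step is False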
tests/tests_pytorch/models/test_hooks.py (1 addition, 5 deletions)

@@ -300,11 +300,7 @@ def _auto_train_batch(
                 dict(
                     name="optimizer_step",
                     args=(current_epoch, i, ANY, ANY),
                 ),
-                *(
-                    [dict(name="lr_scheduler_step", args=(ANY, None))]
-                    if i == (trainer.num_training_batches - 1)
-                    else []
-                ),
+                *([dict(name="lr_scheduler_step", args=ANY)] if i == (trainer.num_training_batches - 1) else []),
                 dict(name="Callback.on_train_batch_end", args=(trainer, model, dict(loss=ANY), ANY, i)),
                 dict(name="on_train_batch_end", args=(dict(loss=ANY), ANY, i)),
             ]