Commit 0e1c38f

fix: Ensure model weights update correctly with mixed precision training (16-mixed)
Parent: be608fa

2 files changed (+83, -0 lines)
Lines changed: 48 additions & 0 deletions
@@ -0,0 +1,48 @@
import torch
from lightning.pytorch.utilities import rank_zero_warn


def optimizer_step(self, optimizer, model, optimizer_idx, closure, **kwargs):
    """Performs the actual optimizer step with proper gradient scaling."""
    scaler = self.scaler

    # Scale the loss and compute gradients
    if closure is not None:
        with torch.cuda.amp.autocast():
            loss = closure()
        scaler.scale(loss).backward()

    try:
        # Unscale gradients before the optimizer step
        scaler.unscale_(optimizer)

        # Check that all gradients are finite
        valid_gradients = True
        for param_group in optimizer.param_groups:
            for param in param_group["params"]:
                if param.grad is not None and not torch.isfinite(param.grad).all():
                    valid_gradients = False
                    break
            if not valid_gradients:
                break

        if valid_gradients:
            # If gradients are valid, step optimizer and update scaler
            optimizer.step()
            scaler.update()
        else:
            # Skip the step but still adjust the scaler
            scaler.update()
            rank_zero_warn(
                "Gradients have become NaN or inf. Skipping optimizer step but updating scaler. "
                "This may affect model convergence.",
                category=RuntimeWarning,
            )
    except RuntimeError as e:
        if "unscale_() has already been called" not in str(e):
            raise
        # unscale_ was already called for this step; step and update directly
        optimizer.step()
        scaler.update()

    optimizer.zero_grad()
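
For reference, the manual finiteness check above mirrors what the stock torch.cuda.amp recipe provides out of the box: GradScaler.step skips the optimizer step internally whenever the unscaled gradients contain inf/NaN. A minimal sketch of that standard recipe, not part of this commit (assumes a CUDA device; the model, data, and hyperparameters are illustrative):

import torch

model = torch.nn.Linear(32, 32).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scaler = torch.cuda.amp.GradScaler()

for _ in range(10):
    x = torch.randn(4, 32, device="cuda")
    optimizer.zero_grad()
    with torch.cuda.amp.autocast():
        loss = torch.nn.functional.mse_loss(model(x), x)
    scaler.scale(loss).backward()
    scaler.unscale_(optimizer)  # optional: exposes unscaled grads for clipping
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
    scaler.step(optimizer)  # skipped internally if any gradient is inf/NaN
    scaler.update()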

tests/tests_pytorch/loops/optimization/test_manual_loop.py

Lines changed: 35 additions & 0 deletions
@@ -42,3 +42,38 @@ def training_step(self, batch, batch_idx):
    with pytest.raises(MisconfigurationException, match="return a Tensor or have no return"):
        trainer.fit(model)


def test_amp_training_updates_weights(tmp_path):
    """Test that model weights are properly updated with mixed precision training."""

    class TestModel(BoringModel):
        def __init__(self):
            super().__init__()
            self.previous_params = None
            self.layer = torch.nn.Linear(32, 32)  # Same input/output size

        def training_step(self, batch, batch_idx):
            # Track parameter changes across steps
            params = torch.cat([param.view(-1) for param in self.parameters()])
            if self.previous_params is not None:
                num_different_values = (self.previous_params != params).sum().item()
                assert num_different_values > 0, f"Parameters did not update at step {batch_idx}"
            self.previous_params = params.clone().detach()

            # Regular training step
            x = batch[0]
            output = self.layer(x)
            loss = torch.nn.functional.mse_loss(output, x)  # Autoencoder-style loss
            return loss

    model = TestModel()
    trainer = Trainer(
        default_root_dir=tmp_path,
        max_epochs=1,
        limit_train_batches=10,
        precision="16-mixed",
        accelerator="auto",
        devices=1,
    )
    trainer.fit(model)
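
The same invariant can also be checked without Lightning, using a bare AMP loop; a minimal sketch (the helper name and the CPU fallback are illustrative assumptions, not part of this commit):

import torch

def assert_amp_updates_weights(steps=5):
    # Fall back to a no-op scaler/autocast on CPU so the sketch runs anywhere.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = torch.nn.Linear(32, 32).to(device)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    scaler = torch.cuda.amp.GradScaler(enabled=device == "cuda")
    for step in range(steps):
        before = torch.cat([p.detach().view(-1).clone() for p in model.parameters()])
        x = torch.randn(4, 32, device=device)
        with torch.autocast(device_type=device, enabled=device == "cuda"):
            loss = torch.nn.functional.mse_loss(model(x), x)
        optimizer.zero_grad()
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        after = torch.cat([p.detach().view(-1) for p in model.parameters()])
        assert (before != after).any(), f"parameters did not update at step {step}"

assert_amp_updates_weights()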
