
Commit 840b7ce

[optim] Handle the case when param groups are passed to optimizer (#2606)
fix param group
1 parent 30f5850 commit 840b7ce

File tree

2 files changed: 30 additions, 1 deletion

  test/test_low_bit_optim.py
  torchao/optim/adam.py

test/test_low_bit_optim.py

Lines changed: 21 additions & 0 deletions

@@ -187,6 +187,27 @@ def test_optim_default_dtype_bf16(self, optim_name, device):
         finally:
             torch.set_default_dtype(old_dtype)
 
+    @parametrize("optim_name", ["Adam8bit", "Adam4bit", "AdamFp8"])
+    @parametrize("device", _DEVICES)
+    def test_param_groups(self, optim_name, device):
+        if optim_name.endswith("Fp8") and device == "cuda":
+            if torch.cuda.get_device_capability() < (8, 9):
+                pytest.skip("FP8 CUDA requires compute capability >= 8.9")
+
+        model = nn.Sequential(nn.Linear(32, 256), nn.ReLU(), nn.Linear(256, 32))
+        model.to(device=device)
+        param_groups = [
+            dict(params=list(model[0].parameters()), lr=1e-4),
+            dict(params=list(model[2].parameters()), lr=1e-5),
+        ]
+        optimizer = getattr(optim, optim_name)(param_groups)
+
+        x = torch.randn(4, 32, device=device)
+        loss = model(x).sum()
+        loss.backward()
+        optimizer.step()
+        optimizer.zero_grad()
+
     # aten.slice is required for dcp.load() when world size changes i.e. re-sharding
     # however, it's cumbersome to test it directly, since we would need to run distributed
     # test 2 times with different world size, and persist checkpoint across the 2 runs.
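For context, a minimal usage sketch of what the new test exercises: passing standard PyTorch-style param groups, each with its own learning rate, to one of the low-bit optimizers. The import path is an assumption based on the torchao/optim/adam.py file touched by this commit; after the fix below, each group's lr is stored as a float32 tensor.

# Usage sketch (assumed import path): per-layer learning rates via param groups.
import torch
import torch.nn as nn
from torchao import optim  # assumption: low-bit optimizers are exposed under torchao.optim

model = nn.Sequential(nn.Linear(32, 256), nn.ReLU(), nn.Linear(256, 32))

# Each group carries its own lr; with this commit, every group's lr is
# converted to a float32 tensor inside add_param_group().
optimizer = optim.Adam8bit([
    {"params": model[0].parameters(), "lr": 1e-4},
    {"params": model[2].parameters(), "lr": 1e-5},
])

loss = model(torch.randn(4, 32)).sum()
loss.backward()
optimizer.step()
optimizer.zero_grad()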

torchao/optim/adam.py

Lines changed: 9 additions & 1 deletion

@@ -39,7 +39,7 @@ def __init__(
         if not 0.0 <= betas[1] < 1.0:
             raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
         defaults = dict(
-            lr=torch.tensor(lr),
+            lr=lr,
             betas=betas,
             eps=eps,
             weight_decay=weight_decay,
@@ -50,6 +50,14 @@ def __init__(
         self.bf16_stochastic_round = bf16_stochastic_round
         self.is_adamw = is_adamw
 
+    def add_param_group(self, param_group: dict) -> None:
+        super().add_param_group(param_group)
+
+        # convert LR to a tensor
+        group = self.param_groups[-1]
+        if not isinstance(group["lr"], Tensor):
+            group["lr"] = torch.tensor(group["lr"], dtype=torch.float32)
+
     def __setstate__(self, state):
         super().__setstate__(state)
         for group in self.param_groups:
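Why moving the conversion fixes the reported case: torch.optim.Optimizer.__init__ calls add_param_group() once per group it receives, so converting lr there covers both the plain Optimizer(model.parameters(), lr=...) path and user-supplied param groups. The previous lr=torch.tensor(lr) in defaults only applied to groups that did not set their own lr, leaving group-specified learning rates as Python floats even though the optimizer stores lr as a tensor. A self-contained sketch of the pattern (the class below is illustrative, not torchao's implementation):

# Illustrative, standalone sketch of the pattern in this commit: keep a plain
# value in `defaults` and normalize per-group hyperparameters in
# add_param_group(), which Optimizer.__init__ invokes once per param group.
import torch
from torch import Tensor
from torch.optim import Optimizer


class TensorLRSGD(Optimizer):
    def __init__(self, params, lr=1e-3):
        super().__init__(params, defaults=dict(lr=lr))

    def add_param_group(self, param_group: dict) -> None:
        super().add_param_group(param_group)
        # Convert this group's lr to a tensor, whether it came from `defaults`
        # or was supplied explicitly in the param group dict.
        group = self.param_groups[-1]
        if not isinstance(group["lr"], Tensor):
            group["lr"] = torch.tensor(group["lr"], dtype=torch.float32)

    @torch.no_grad()
    def step(self):
        for group in self.param_groups:
            for p in group["params"]:
                if p.grad is not None:
                    p.add_(p.grad, alpha=-group["lr"].item())


# Both construction paths funnel through add_param_group, so every lr is a tensor.
w1 = torch.nn.Parameter(torch.randn(2))
w2 = torch.nn.Parameter(torch.randn(2))
opt = TensorLRSGD([{"params": [w1], "lr": 1e-4}])
opt.add_param_group({"params": [w2], "lr": 1e-5})
assert all(isinstance(g["lr"], Tensor) for g in opt.param_groups)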
