Commit f4648b0

Merge pull request #267 from kozistr/fix/prodigy-optimizer
[Fix] device mismatch problems
2 parents 58f923f + fa94e7c commit f4648b0

File tree

9 files changed: +13 -10 lines changed

docs/changelogs/v3.1.1.md

Lines changed: 2 additions & 1 deletion
@@ -10,4 +10,5 @@

 ### Bug

-* Fix to handle the optimizers that only take the `model` instead of the parameters in `create_optimizer()`. (#263)
+* Handle the optimizers that only take the `model` instead of the parameters in `create_optimizer()`. (#263)
+* Move the variable to the same device with the parameter. (#266, #267)

pytorch_optimizer/optimizer/adamg.py

Lines changed: 1 addition & 1 deletion
@@ -27,7 +27,7 @@ def __init__(
         lr: float = 1e-3,
         betas: BETAS = (0.95, 0.999, 0.95),
         p: float = 0.5,
-        q: float = 0.25,
+        q: float = 0.24,
         weight_decay: float = 0.0,
         weight_decouple: bool = False,
         fixed_decay: bool = False,

pytorch_optimizer/optimizer/adan.py

Lines changed: 1 addition & 1 deletion
@@ -83,7 +83,7 @@ def get_global_gradient_norm(self) -> Union[torch.Tensor, float]:
         if self.defaults['max_grad_norm'] == 0.0:
             return 1.0

-        global_grad_norm = get_global_gradient_norm(self.param_groups, self.param_groups[0]['params'][0].device)
+        global_grad_norm = get_global_gradient_norm(self.param_groups)
         global_grad_norm.sqrt_().add_(self.defaults['eps'])

         return torch.clamp(self.defaults['max_grad_norm'] / global_grad_norm, max=1.0)
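
Note: the following is a standalone sketch for orientation, not code from this commit. It illustrates how a clipping coefficient like the one returned above is typically applied: the gradients are scaled in place so their global L2 norm stays at or below `max_grad_norm`. The names `max_grad_norm` and `params` here are made up for the example.

import torch

# Hypothetical setup; values are arbitrary.
max_grad_norm = 1.0
params = [torch.randn(3, requires_grad=True) for _ in range(2)]
for p in params:
    p.grad = torch.randn_like(p) * 10.0

# Accumulate squared L2 norms on the parameters' device, mirroring the helper above.
global_grad_norm = torch.zeros(1, dtype=torch.float32, device=params[0].device)
for p in params:
    global_grad_norm.add_(torch.linalg.norm(p.grad).pow(2))

# clamp(max=1.0) makes this a no-op when the norm is already within bounds.
clip = torch.clamp(max_grad_norm / global_grad_norm.sqrt().add_(1e-8), max=1.0)

for p in params:
    p.grad.mul_(clip)  # scale gradients in place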

pytorch_optimizer/optimizer/alig.py

Lines changed: 1 addition & 1 deletion
@@ -52,7 +52,7 @@ def reset(self):
     @torch.no_grad()
     def compute_step_size(self, loss: float) -> float:
         r"""Compute step_size."""
-        global_grad_norm = get_global_gradient_norm(self.param_groups, torch.device('cpu'))
+        global_grad_norm = get_global_gradient_norm(self.param_groups)
         global_grad_norm.add_(1e-6)

         return loss / global_grad_norm.item()

pytorch_optimizer/optimizer/dadapt.py

Lines changed: 1 addition & 1 deletion
@@ -473,7 +473,7 @@ def step(self, closure: CLOSURE = None) -> LOSS:
         numerator_weighted = group['numerator_weighted']

         if group['step'] == 0:
-            group['g0_norm'] = get_global_gradient_norm(self.param_groups, device).sqrt_().item()
+            group['g0_norm'] = get_global_gradient_norm(self.param_groups).sqrt_().item()
         g0_norm = group['g0_norm']

         if g0_norm == 0:

pytorch_optimizer/optimizer/lamb.py

Lines changed: 1 addition & 1 deletion
@@ -103,7 +103,7 @@ def get_global_gradient_norm(self) -> Union[torch.Tensor, float]:
         if self.defaults['max_grad_norm'] == 0.0:
             return 1.0

-        global_grad_norm = get_global_gradient_norm(self.param_groups, self.param_groups[0]['params'][0].device)
+        global_grad_norm = get_global_gradient_norm(self.param_groups)
         global_grad_norm.sqrt_().add_(self.defaults['eps'])

         return torch.clamp(self.defaults['max_grad_norm'] / global_grad_norm, max=1.0)

pytorch_optimizer/optimizer/prodigy.py

Lines changed: 2 additions & 0 deletions
@@ -110,6 +110,8 @@ def step(self, closure: CLOSURE = None) -> LOSS:

         if 'd_numerator' not in group:
             group['d_numerator'] = torch.tensor([0.0], device=device)
+        elif group['d_numerator'].device != device:
+            group['d_numerator'] = group['d_numerator'].to(device)

         d_numerator = group['d_numerator']
         d_numerator.mul_(beta3)
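
For context, a standalone sketch (not part of this commit): the added `elif` covers the case where the accumulator was created, or restored via `load_state_dict`, on a different device than the parameters now live on; moving it with `.to(device)` keeps the accumulated value instead of resetting it. The `group` dict below is made up to illustrate the pattern.

import torch

# Made-up group dict mirroring the per-group state layout used above.
target = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
group = {
    'params': [torch.nn.Parameter(torch.randn(2, 2, device=target))],
    'd_numerator': torch.tensor([3.5]),  # e.g. restored from a CPU checkpoint
}
device = group['params'][0].device

if 'd_numerator' not in group:
    # First step: create the accumulator on the parameter's device.
    group['d_numerator'] = torch.tensor([0.0], device=device)
elif group['d_numerator'].device != device:
    # State and parameters disagree (CPU checkpoint vs. GPU model):
    # move the existing value instead of recreating it.
    group['d_numerator'] = group['d_numerator'].to(device)

# Subsequent in-place math against parameter-device tensors now works.
group['d_numerator'].mul_(0.999)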

pytorch_optimizer/optimizer/utils.py

Lines changed: 2 additions & 2 deletions
@@ -272,9 +272,9 @@ def l2_projection(parameters: PARAMETERS, max_norm: float = 1e2):


 @torch.no_grad()
-def get_global_gradient_norm(param_groups: List[Dict], device: torch.device) -> torch.Tensor:
+def get_global_gradient_norm(param_groups: List[Dict]) -> torch.Tensor:
     r"""Get global gradient norm."""
-    global_grad_norm = torch.zeros(1, dtype=torch.float32, device=device)
+    global_grad_norm = torch.zeros(1, dtype=torch.float32, device=param_groups[0]['params'][0].device)

     for group in param_groups:
         for p in group['params']:
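
Putting the change together: a hedged sketch of the updated helper. The loop body past the lines shown in the hunk is an assumption (summing squared per-parameter gradient norms), chosen because the callers above apply `.sqrt_()` to the result; the real implementation lives in pytorch_optimizer/optimizer/utils.py.

from typing import Dict, List

import torch


@torch.no_grad()
def get_global_gradient_norm_sketch(param_groups: List[Dict]) -> torch.Tensor:
    """Sketch only: the accumulator now lives on the device of the first
    parameter, so callers no longer pass a device argument."""
    global_grad_norm = torch.zeros(1, dtype=torch.float32, device=param_groups[0]['params'][0].device)

    for group in param_groups:
        for p in group['params']:
            if p.grad is not None:
                # Assumed body: accumulate squared L2 norms; callers sqrt_() afterwards.
                global_grad_norm.add_(torch.linalg.norm(p.grad).pow(2))

    return global_grad_norm

With this signature, the call sites reduce to `get_global_gradient_norm(self.param_groups)`, as the adan.py, alig.py, dadapt.py, and lamb.py hunks above show.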

tests/constants.py

Lines changed: 2 additions & 2 deletions
@@ -428,8 +428,8 @@
     (Prodigy, {'lr': 1e0, 'beta3': 0.999, 'weight_decay': 1e-3, 'safeguard_warmup': True}, 15),
     (PAdam, {'lr': 1e0, 'weight_decay': 1e-3}, 5),
     (Tiger, {'lr': 1e0, 'weight_decay': 1e-3}, 5),
-    (CAME, {'lr': 7.5e-1, 'weight_decay': 1e-3}, 75),
-    (CAME, {'lr': 7.5e-1, 'weight_decay': 1e-3, 'ams_bound': True}, 75),
+    (CAME, {'lr': 7.5e-1, 'weight_decay': 1e-3}, 70),
+    (CAME, {'lr': 7.5e-1, 'weight_decay': 1e-3, 'ams_bound': True}, 70),
     (Aida, {'lr': 1e0, 'weight_decay': 1e-3, 'ams_bound': True}, 5),
     (
         GaLore,

0 commit comments
