
Commit 0386506

Merge pull request #347 from kozistr/fix/adopt-optimizer
[Fix] Updating `exp_avg_sq` after calculating the `denominator` in `ADOPT` optimizer
2 parents: b82f7c4 + db3a9ab

File tree: 4 files changed (+16 lines, -7 lines)

* docs/changelogs/v3.4.1.md
* pytorch_optimizer/optimizer/adopt.py
* pytorch_optimizer/optimizer/experimental/ranger25.py
* tests/constants.py

docs/changelogs/v3.4.1.md

Lines changed: 9 additions & 0 deletions

@@ -15,7 +15,16 @@
 * `AGC` + `Lookahead` variants
 * change default beta1, beta2 to 0.95 and 0.98 respectively
 * Skip adding `Lookahead` wrapper in case of `Ranger*` optimizers, which already have it in `create_optimizer()`. (#340)
+* Improved optimizer visualization. (#345)
+
+### Bug
+
+* Fix to update exp_avg_sq after calculating the denominator in `ADOPT` optimizer. (#346, #347)
 
 ### Docs
 
 * Update the visualizations. (#340)
+
+### Contributions
+
+thanks to @AidinHamedi
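The ordering matters because ADOPT deliberately normalizes the current gradient by the second-moment estimate of the *previous* step and refreshes that estimate only afterwards, keeping the denominator decorrelated from the gradient it divides. Roughly, and leaving out ADOPT's clipping of the normalized gradient and other details (a sketch, not the library's exact formulation):

```latex
\begin{aligned}
m_t      &= \beta_1 m_{t-1} + (1 - \beta_1)\,\frac{g_t}{\max\!\left(\sqrt{v_{t-1}},\, \epsilon\right)}, \\
\theta_t &= \theta_{t-1} - \alpha\, m_t, \\
v_t      &= \beta_2 v_{t-1} + (1 - \beta_2)\, g_t^2 .
\end{aligned}
```

Refreshing `v` before building the denominator, as the old code did, reintroduces exactly the correlation between the gradient and its denominator that ADOPT is designed to avoid.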

pytorch_optimizer/optimizer/adopt.py

Lines changed: 2 additions & 2 deletions

@@ -115,8 +115,6 @@ def step(self, closure: CLOSURE = None) -> LOSS:
                     exp_avg_sq.addcmul_(grad, grad.conj())
                     continue
 
-                exp_avg_sq.mul_(beta2).addcmul_(grad, grad.conj(), value=1.0 - beta2)
-
                 de_nom = exp_avg_sq.sqrt().clamp_(min=group['eps'])
 
                 normed_grad = grad.div(de_nom)
@@ -137,4 +135,6 @@ def step(self, closure: CLOSURE = None) -> LOSS:
 
                 p.add_(update, alpha=-lr)
 
+                exp_avg_sq.mul_(beta2).addcmul_(grad, grad.conj(), value=1.0 - beta2)
+
         return loss
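For reference, a minimal, self-contained sketch of the corrected ordering for a single dense tensor; it is simplified (no bias handling, weight decay, cautious updates, or the other options the real class supports), and `adopt_like_step` with its default hyperparameters is illustrative rather than the library's API:

```python
import torch


def adopt_like_step(
    p: torch.Tensor,
    grad: torch.Tensor,
    exp_avg: torch.Tensor,
    exp_avg_sq: torch.Tensor,
    lr: float = 1e-3,
    beta1: float = 0.9,
    beta2: float = 0.9999,
    eps: float = 1e-6,
) -> None:
    """One simplified ADOPT-style update on a single tensor (illustration only)."""
    # denominator built from the *previous* step's second-moment estimate
    de_nom = exp_avg_sq.sqrt().clamp_(min=eps)

    # normalize the current gradient by that denominator
    normed_grad = grad.div(de_nom)

    # first-moment (momentum) update, then the parameter step
    exp_avg.mul_(beta1).add_(normed_grad, alpha=1.0 - beta1)
    p.add_(exp_avg, alpha=-lr)

    # refresh the second moment only now, so the next step's denominator
    # does not depend on the gradient it will later divide
    exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1.0 - beta2)


# toy usage: per the diff above, the very first step only seeds exp_avg_sq,
# so here we start from an already-seeded second moment
p = torch.randn(3)
exp_avg = torch.zeros_like(p)
exp_avg_sq = torch.randn(3).square()  # stand-in for the seeded second moment
adopt_like_step(p, torch.randn(3), exp_avg, exp_avg_sq)
```

The only behavioral change in the hunks above is that the `exp_avg_sq` refresh now happens after `de_nom` and the parameter update, matching this order.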

pytorch_optimizer/optimizer/experimental/ranger25.py

Lines changed: 2 additions & 2 deletions

@@ -183,10 +183,10 @@ def step(self, closure: CLOSURE = None) -> LOSS:
                         fixed_decay=group['fixed_decay'],
                     )
 
-                exp_avg, exp_avg_sq, exp_avg_slow = state['exp_avg'], state['exp_avg_sq'], state['exp_avg_slow']
-
                 grad.copy_(agc(p, grad))
 
+                exp_avg, exp_avg_sq, exp_avg_slow = state['exp_avg'], state['exp_avg_sq'], state['exp_avg_slow']
+
                 normed_grad = grad.div(
                     exp_avg_sq.sqrt().clamp_(min=group['eps'] if group['eps'] is not None else 1e-8)
                 ).clamp_(-clip, clip)
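In this hunk the gradient is rewritten in place by `agc(p, grad)` before the optimizer state is read, so the reorder simply moves the unpacking next to where those tensors are first used. For readers unfamiliar with the call, a rough tensor-wise sketch of what an AGC-style (adaptive gradient clipping) step does is below; the library's actual `agc` helper is unit-wise and takes further parameters, and the name `agc_sketch` with its defaults is made up for this example:

```python
import torch


def agc_sketch(
    param: torch.Tensor,
    grad: torch.Tensor,
    clip_factor: float = 1e-2,
    eps: float = 1e-3,
) -> torch.Tensor:
    """Shrink the gradient when its norm is too large relative to the parameter norm."""
    p_norm = param.norm().clamp(min=eps)   # guard against (near-)zero parameters
    g_norm = grad.norm()
    max_norm = p_norm * clip_factor        # largest gradient norm allowed for this tensor

    # rescale the gradient only when it exceeds the allowed norm
    clipped = grad * (max_norm / g_norm.clamp(min=1e-12))
    return torch.where(g_norm > max_norm, clipped, grad)


# toy usage mirroring the call site above: the result overwrites `grad` in place
p, grad = torch.randn(4, 4), 100.0 * torch.randn(4, 4)
grad.copy_(agc_sketch(p, grad))
```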

tests/constants.py

Lines changed: 3 additions & 3 deletions

@@ -563,9 +563,9 @@
     (FOCUS, {'lr': 1e-1, 'weight_decay': 1e-3}, 5),
     (Kron, {'lr': 1e0, 'weight_decay': 1e-3}, 3),
     (EXAdam, {'lr': 1e-1, 'weight_decay': 1e-3}, 5),
-    (Ranger25, {'lr': 5e-2}, 5),
-    (Ranger25, {'lr': 5e-2, 't_alpha_beta3': 5}, 5),
-    (Ranger25, {'lr': 5e-2, 'stable_adamw': False, 'orthograd': False, 'eps': None}, 5),
+    (Ranger25, {'lr': 1e-1}, 3),
+    (Ranger25, {'lr': 1e-1, 't_alpha_beta3': 5}, 3),
+    (Ranger25, {'lr': 5e-2, 'stable_adamw': False, 'orthograd': False, 'eps': None, 'lookahead_merge_time': 2}, 3),
 ]
 ADANORM_SUPPORTED_OPTIMIZERS: List[Tuple[Any, Dict[str, Union[float, bool, int]], int]] = [
     (AdaBelief, {'lr': 5e-1, 'weight_decay': 1e-3, 'adanorm': True}, 10),
