Skip to content

Commit a628749

Browse files
committed
feature: support cautious variant
1 parent b573013 commit a628749

File tree

1 file changed

+10
-1
lines changed

1 file changed

+10
-1
lines changed

pytorch_optimizer/optimizer/adopt.py

Lines changed: 10 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -17,6 +17,7 @@ class ADOPT(BaseOptimizer):
1717
:param weight_decay: float. weight decay (L2 penalty).
1818
:param weight_decouple: bool. the optimizer uses decoupled weight decay as in AdamW.
1919
:param fixed_decay: bool. fix weight decay.
20+
:param cautious: bool. whether to use the Cautious variant.
2021
:param eps: float. term added to the denominator to improve numerical stability.
2122
"""
2223

@@ -29,6 +30,7 @@ def __init__(
2930
weight_decay: float = 0.0,
3031
weight_decouple: bool = False,
3132
fixed_decay: bool = False,
33+
cautious: bool = False,
3234
eps: float = 1e-6,
3335
**kwargs,
3436
):
@@ -38,6 +40,7 @@ def __init__(
3840
self.validate_non_negative(eps, 'eps')
3941

4042
self.clip_lambda = clip_lambda
43+
self.cautious = cautious
4144

4245
defaults: DEFAULTS = {
4346
'lr': lr,
@@ -118,6 +121,12 @@ def step(self, closure: CLOSURE = None) -> LOSS:
118121

119122
exp_avg.lerp_(normed_grad, weight=1.0 - beta1)
120123

121-
p.add_(exp_avg, alpha=-group['lr'])
124+
if self.cautious:
125+
update = exp_avg.clone()
126+
self.apply_cautious(update, normed_grad)
127+
else:
128+
update = exp_avg
129+
130+
p.add_(update, alpha=-group['lr'])
122131

123132
return loss

0 commit comments

Comments
 (0)