@@ -17,6 +17,7 @@ class ADOPT(BaseOptimizer):
     :param weight_decay: float. weight decay (L2 penalty).
     :param weight_decouple: bool. the optimizer uses decoupled weight decay as in AdamW.
     :param fixed_decay: bool. fix weight decay.
+    :param cautious: bool. whether to use the Cautious variant.
     :param eps: float. term added to the denominator to improve numerical stability.
     """

@@ -29,6 +30,7 @@ def __init__(
         weight_decay: float = 0.0,
         weight_decouple: bool = False,
         fixed_decay: bool = False,
+        cautious: bool = False,
         eps: float = 1e-6,
         **kwargs,
     ):
@@ -38,6 +40,7 @@ def __init__(
        self.validate_non_negative(eps, 'eps')

        self.clip_lambda = clip_lambda
+        self.cautious = cautious

        defaults: DEFAULTS = {
            'lr': lr,
@@ -118,6 +121,12 @@ def step(self, closure: CLOSURE = None) -> LOSS:

                exp_avg.lerp_(normed_grad, weight=1.0 - beta1)

-                p.add_(exp_avg, alpha=-group['lr'])
+                if self.cautious:
+                    update = exp_avg.clone()
+                    self.apply_cautious(update, normed_grad)
+                else:
+                    update = exp_avg
+
+                p.add_(update, alpha=-group['lr'])

        return loss
0 commit comments