@@ -1,3 +1,5 @@
+import math
+
 import torch
 from torch.optim.optimizer import Optimizer
 
@@ -8,7 +10,7 @@
 
 class Adan(Optimizer, BaseOptimizer):
     """
-    Reference : x
+    Reference : https://github.com/sail-sg/Adan/blob/main/adan.py
     Example :
         from pytorch_optimizer import Adan
         ...
@@ -27,21 +29,24 @@ def __init__(
         params: PARAMETERS,
         lr: float = 1e-3,
         betas: BETAS = (0.98, 0.92, 0.99),
-        weight_decay: float = 0.02,
+        weight_decay: float = 0.0,
+        weight_decouple: bool = False,
         use_gc: bool = False,
-        eps: float = 1e-16,
+        eps: float = 1e-8,
     ):
         """Adan optimizer
         :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups
         :param lr: float. learning rate
         :param betas: BETAS. coefficients used for computing running averages of gradient and the squared hessian trace
         :param weight_decay: float. weight decay (L2 penalty)
+        :param weight_decouple: bool. decoupled weight decay
         :param use_gc: bool. use gradient centralization
         :param eps: float. term added to the denominator to improve numerical stability
         """
         self.lr = lr
         self.betas = betas
         self.weight_decay = weight_decay
+        self.weight_decouple = weight_decouple
         self.use_gc = use_gc
         self.eps = eps

@@ -52,6 +57,7 @@ def __init__(
             betas=betas,
             eps=eps,
             weight_decay=weight_decay,
+            weight_decouple=weight_decouple,
         )
         super().__init__(params, defaults)

@@ -69,7 +75,7 @@ def reset(self):

                 state['step'] = 0
                 state['exp_avg'] = torch.zeros_like(p)
-                state['exp_avg_var'] = torch.zeros_like(p)
+                state['exp_avg_diff'] = torch.zeros_like(p)
                 state['exp_avg_nest'] = torch.zeros_like(p)
                 state['previous_grad'] = torch.zeros_like(p)

@@ -93,29 +99,40 @@ def step(self, closure: CLOSURE = None) -> LOSS:
                 if len(state) == 0:
                     state['step'] = 0
                     state['exp_avg'] = torch.zeros_like(p)
-                    state['exp_avg_var'] = torch.zeros_like(p)
+                    state['exp_avg_diff'] = torch.zeros_like(p)
                     state['exp_avg_nest'] = torch.zeros_like(p)
                     state['previous_grad'] = torch.zeros_like(p)

-                exp_avg, exp_avg_var, exp_avg_nest = state['exp_avg'], state['exp_avg_var'], state['exp_avg_nest']
+                exp_avg, exp_avg_diff, exp_avg_nest = state['exp_avg'], state['exp_avg_diff'], state['exp_avg_nest']
                 prev_grad = state['previous_grad']

                 state['step'] += 1
                 beta1, beta2, beta3 = group['betas']

+                bias_correction1 = 1.0 - beta1 ** state['step']
+                bias_correction2 = 1.0 - beta2 ** state['step']
+                bias_correction3 = 1.0 - beta3 ** state['step']
+
                 if self.use_gc:
                     grad = centralize_gradient(grad, gc_conv_only=False)

                 grad_diff = grad - prev_grad
                 state['previous_grad'] = grad.clone()

-                exp_avg.mul_(beta1).add_(grad, alpha=1.0 - beta1)
-                exp_avg_var.mul_(beta2).add_(grad_diff, alpha=1.0 - beta2)
-                exp_avg_nest.mul_(beta3).add_((grad + beta2 * grad_diff) ** 2, alpha=1.0 - beta3)
+                update = grad + beta2 * grad_diff

-                step_size = group['lr'] / exp_avg_nest.add_(self.eps).sqrt_()
-
-                p.sub_(step_size * (exp_avg + beta2 * exp_avg_var))
-                p.div_(1.0 + group['weight_decay'])
+                exp_avg.mul_(beta1).add_(grad, alpha=1.0 - beta1)
+                exp_avg_diff.mul_(beta2).add_(grad_diff, alpha=1.0 - beta2)
+                exp_avg_nest.mul_(beta3).addcmul_(update, update, value=1.0 - beta3)
+
+                de_nom = (exp_avg_nest.sqrt_() / math.sqrt(bias_correction3)).add_(self.eps)
+                perturb = (exp_avg / bias_correction1 + beta2 * exp_avg_diff / bias_correction2).div_(de_nom)
+
+                if group['weight_decouple']:
+                    p.mul_(1.0 - group['lr'] * group['weight_decay'])
+                    p.add_(perturb, alpha=-group['lr'])
+                else:
+                    p.add_(perturb, alpha=-group['lr'])
+                    p.div_(1.0 + group['lr'] * group['weight_decay'])

         return loss
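
For context, a minimal usage sketch of the optimizer as changed in this diff (the tiny model and random batch are placeholders, not part of the change; note that weight_decay now defaults to 0.0, so decay must be requested explicitly, and weight_decouple=True switches from the divide-style penalty to AdamW-style decoupled decay):

import torch

from pytorch_optimizer import Adan

model = torch.nn.Linear(10, 1)  # placeholder model for illustration
optimizer = Adan(model.parameters(), lr=1e-3, weight_decay=1e-2, weight_decouple=True)

loss = model(torch.randn(8, 10)).pow(2).mean()  # placeholder objective
loss.backward()
optimizer.step()       # decoupled branch: p <- (1 - lr * wd) * p - lr * perturb
optimizer.zero_grad()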