 from pytorch_optimizer.base.optimizer import BaseOptimizer
 from pytorch_optimizer.base.types import BETAS, CLOSURE, DEFAULTS, LOSS, PARAMETERS
 from pytorch_optimizer.optimizer.gc import centralize_gradient
+from pytorch_optimizer.optimizer.utils import get_global_gradient_norm
 
 
 class Adan(Optimizer, BaseOptimizer):
@@ -20,6 +21,8 @@ class Adan(Optimizer, BaseOptimizer):
     :param weight_decouple: bool. decoupled weight decay.
     :param max_grad_norm: float. max gradient norm to clip.
     :param use_gc: bool. use gradient centralization.
+    :param r: float. EMA factor. a value between 0.9 and 0.99 is preferred.
+    :param adanorm: bool. whether to use the AdaNorm variant.
     :param eps: float. term added to the denominator to improve numerical stability.
     """
 
@@ -32,6 +35,8 @@ def __init__(
         weight_decouple: bool = False,
         max_grad_norm: float = 0.0,
         use_gc: bool = False,
+        r: float = 0.95,
+        adanorm: bool = False,
         eps: float = 1e-8,
     ):
         self.lr = lr
@@ -49,8 +54,12 @@ def __init__(
             'weight_decay': weight_decay,
             'weight_decouple': weight_decouple,
             'max_grad_norm': max_grad_norm,
+            'adanorm': adanorm,
             'eps': eps,
         }
+        if adanorm:
+            defaults.update({'r': r})
+
         super().__init__(params, defaults)
 
     def validate_parameters(self):
@@ -71,25 +80,21 @@ def reset(self):
                 state = self.state[p]
 
                 state['exp_avg'] = torch.zeros_like(p)
+                state['exp_avg_sq'] = torch.zeros_like(p)
                 state['exp_avg_diff'] = torch.zeros_like(p)
-                state['exp_avg_nest'] = torch.zeros_like(p)
                 state['previous_grad'] = torch.zeros_like(p)
+                if group['adanorm']:
+                    state['exp_grad_norm'] = torch.zeros((1,), dtype=p.dtype, device=p.device)
 
     @torch.no_grad()
     def get_global_gradient_norm(self) -> Union[torch.Tensor, float]:
         if self.defaults['max_grad_norm'] == 0.0:
             return 1.0
 
-        global_grad_norm = torch.zeros(1, dtype=torch.float32, device=self.param_groups[0]['params'][0].device)
-
-        for group in self.param_groups:
-            for p in group['params']:
-                if p.grad is not None:
-                    global_grad_norm.add_(torch.linalg.norm(p.grad).pow(2))
-
-        global_grad_norm.sqrt_()
+        global_grad_norm = get_global_gradient_norm(self.param_groups, self.param_groups[0]['params'][0].device)
+        global_grad_norm.sqrt_().add_(self.eps)
 
-        return torch.clamp(self.defaults['max_grad_norm'] / (global_grad_norm + self.eps), max=1.0)
+        return torch.clamp(self.defaults['max_grad_norm'] / global_grad_norm, max=1.0)
 
     @torch.no_grad()
     def step(self, closure: CLOSURE = None) -> LOSS:
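For reference, when max_grad_norm is positive the factor returned above implements standard global-norm clipping: every gradient is later scaled by min(max_grad_norm / (||g||_2 + eps), 1). A minimal standalone sketch, with toy tensors and illustrative values rather than the optimizer's internals:

import torch

# Toy gradients; max_grad_norm and eps are illustrative values, not the library's defaults.
grads = [torch.randn(3, 3), torch.randn(5)]
max_grad_norm, eps = 1.0, 1e-8

# Global L2 norm over all gradients, then the clip factor applied to each gradient in step().
global_norm = torch.sqrt(sum(torch.linalg.norm(g).pow(2) for g in grads))
clip_factor = torch.clamp(max_grad_norm / (global_norm + eps), max=1.0)

clipped = [g * clip_factor for g in grads]  # every gradient is scaled by the same factor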
@@ -122,35 +127,50 @@ def step(self, closure: CLOSURE = None) -> LOSS:
                 state = self.state[p]
                 if len(state) == 0:
                     state['exp_avg'] = torch.zeros_like(p)
+                    state['exp_avg_sq'] = torch.zeros_like(p)
                     state['exp_avg_diff'] = torch.zeros_like(p)
-                    state['exp_avg_nest'] = torch.zeros_like(p)
-                    state['previous_grad'] = grad.clone()
+                    state['previous_grad'] = grad.clone().mul_(-clip_global_grad_norm)
+                    if group['adanorm']:
+                        state['exp_grad_norm'] = torch.zeros((1,), dtype=grad.dtype, device=grad.device)
 
                 grad.mul_(clip_global_grad_norm)
 
                 if self.use_gc:
                     grad = centralize_gradient(grad, gc_conv_only=False)
 
-                grad_diff = -state['previous_grad']
+                grad_diff = state['previous_grad']
                 grad_diff.add_(grad)
-                state['previous_grad'].copy_(grad)
 
-                update = grad + beta2 * grad_diff
+                s_grad = grad
+                if group['adanorm']:
+                    grad_norm = torch.linalg.norm(grad)
+
+                    exp_grad_norm = state['exp_grad_norm']
+                    exp_grad_norm.mul_(group['r']).add_(grad_norm, alpha=1.0 - group['r'])
 
-                exp_avg, exp_avg_diff, exp_avg_nest = state['exp_avg'], state['exp_avg_diff'], state['exp_avg_nest']
+                    if exp_grad_norm > grad_norm:
+                        s_grad *= exp_grad_norm / grad_norm
 
-                exp_avg.mul_(beta1).add_(grad, alpha=1.0 - beta1)
+                exp_avg, exp_avg_sq, exp_avg_diff = state['exp_avg'], state['exp_avg_sq'], state['exp_avg_diff']
+
+                exp_avg.mul_(beta1).add_(s_grad, alpha=1.0 - beta1)
                 exp_avg_diff.mul_(beta2).add_(grad_diff, alpha=1.0 - beta2)
-                exp_avg_nest.mul_(beta3).addcmul_(update, update, value=1.0 - beta3)
 
-                de_nom = (exp_avg_nest.sqrt_() / bias_correction3_sq).add_(self.eps)
-                perturb = (exp_avg / bias_correction1 + beta2 * exp_avg_diff / bias_correction2).div_(de_nom)
+                grad_diff.mul_(beta2).add_(grad)
+                exp_avg_sq.mul_(beta3).addcmul_(grad_diff, grad_diff, value=1.0 - beta3)
+
+                de_nom = exp_avg_sq.sqrt()
+                de_nom.div_(bias_correction3_sq).add_(group['eps'])
 
                 if group['weight_decouple']:
                     p.mul_(1.0 - group['lr'] * group['weight_decay'])
-                    p.add_(perturb, alpha=-group['lr'])
-                else:
-                    p.add_(perturb, alpha=-group['lr'])
+
+                p.addcdiv_(exp_avg, de_nom, value=-group['lr'] / bias_correction1)
+                p.addcdiv_(exp_avg_diff, de_nom, value=-group['lr'] * beta2 / bias_correction2)
+
+                if not group['weight_decouple']:
                     p.div_(1.0 + group['lr'] * group['weight_decay'])
 
+                state['previous_grad'].copy_(-grad)
+
         return loss
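With these changes the AdaNorm path is opt-in: r is only stored in defaults when adanorm=True. A minimal usage sketch of the new flags; the toy model, the hyper-parameter values, and the top-level import of Adan are assumptions for illustration:

import torch
import torch.nn as nn

from pytorch_optimizer import Adan  # import path assumed; adjust to the installed package layout

model = nn.Linear(10, 1)
optimizer = Adan(model.parameters(), lr=1e-3, max_grad_norm=1.0, adanorm=True, r=0.95)

# one illustrative training step
loss = model(torch.randn(4, 10)).pow(2).mean()
loss.backward()
optimizer.step()
optimizer.zero_grad()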