 import math
+from typing import Union
 
 import torch
 from torch.optim.optimizer import Optimizer
@@ -17,6 +18,7 @@ class Adan(Optimizer, BaseOptimizer):
     :param betas: BETAS. coefficients used for computing running averages of gradient and the squared hessian trace.
     :param weight_decay: float. weight decay (L2 penalty).
     :param weight_decouple: bool. decoupled weight decay.
+    :param max_grad_norm: float. max gradient norm to clip.
     :param use_gc: bool. use gradient centralization.
     :param eps: float. term added to the denominator to improve numerical stability.
     """
@@ -28,13 +30,15 @@ def __init__(
         betas: BETAS = (0.98, 0.92, 0.99),
         weight_decay: float = 0.0,
         weight_decouple: bool = False,
+        max_grad_norm: float = 0.0,
         use_gc: bool = False,
         eps: float = 1e-8,
     ):
         self.lr = lr
         self.betas = betas
         self.weight_decay = weight_decay
         self.weight_decouple = weight_decouple
+        self.max_grad_norm = max_grad_norm
         self.use_gc = use_gc
         self.eps = eps
 
@@ -46,6 +50,7 @@ def __init__(
             eps=eps,
             weight_decay=weight_decay,
             weight_decouple=weight_decouple,
+            max_grad_norm=max_grad_norm,
         )
         super().__init__(params, defaults)
 
@@ -54,6 +59,7 @@ def validate_parameters(self):
         self.validate_betas(self.betas)
         self.validate_weight_decay(self.weight_decay)
         self.validate_epsilon(self.eps)
+        self.validate_norm(self.max_grad_norm)
 
     @property
     def __name__(self) -> str:
@@ -62,23 +68,54 @@ def __name__(self) -> str:
     @torch.no_grad()
     def reset(self):
         for group in self.param_groups:
+            group['step'] = 0
             for p in group['params']:
                 state = self.state[p]
 
-                state['step'] = 0
                 state['exp_avg'] = torch.zeros_like(p)
                 state['exp_avg_diff'] = torch.zeros_like(p)
                 state['exp_avg_nest'] = torch.zeros_like(p)
                 state['previous_grad'] = torch.zeros_like(p)
 
+    @torch.no_grad()
+    def get_global_gradient_norm(self) -> Union[torch.Tensor, float]:
+        if self.defaults['max_grad_norm'] == 0.0:
+            return 1.0
+
+        device = self.param_groups[0]['params'][0].device
+
+        global_grad_norm = torch.zeros(1, device=device)
+        max_grad_norm = torch.tensor(self.defaults['max_grad_norm'], device=device)
+
+        for group in self.param_groups:
+            for p in group['params']:
+                if p.grad is not None:
+                    global_grad_norm.add_(torch.linalg.norm(p.grad).pow(2))
+
+        global_grad_norm = torch.sqrt(global_grad_norm)
+
+        return torch.clamp(max_grad_norm / (global_grad_norm + self.eps), max=1.0)
+
     @torch.no_grad()
     def step(self, closure: CLOSURE = None) -> LOSS:
         loss: LOSS = None
         if closure is not None:
             with torch.enable_grad():
                 loss = closure()
 
+        clip_global_grad_norm = self.get_global_gradient_norm()
+
         for group in self.param_groups:
+            if 'step' in group:
+                group['step'] += 1
+            else:
+                group['step'] = 1
+
+            beta1, beta2, beta3 = group['betas']
+            bias_correction1 = 1.0 - beta1 ** group['step']
+            bias_correction2 = 1.0 - beta2 ** group['step']
+            bias_correction3_sq = math.sqrt(1.0 - beta3 ** group['step'])
+
             for p in group['params']:
                 if p.grad is None:
                     continue
@@ -89,35 +126,28 @@ def step(self, closure: CLOSURE = None) -> LOSS:
 
                 state = self.state[p]
                 if len(state) == 0:
-                    state['step'] = 0
                     state['exp_avg'] = torch.zeros_like(p)
                     state['exp_avg_diff'] = torch.zeros_like(p)
                     state['exp_avg_nest'] = torch.zeros_like(p)
-                    state['previous_grad'] = torch.zeros_like(p)
+                    state['previous_grad'] = grad.clone()
 
                 exp_avg, exp_avg_diff, exp_avg_nest = state['exp_avg'], state['exp_avg_diff'], state['exp_avg_nest']
-                prev_grad = state['previous_grad']
-
-                state['step'] += 1
-                beta1, beta2, beta3 = group['betas']
 
-                bias_correction1 = 1.0 - beta1 ** state['step']
-                bias_correction2 = 1.0 - beta2 ** state['step']
-                bias_correction3 = 1.0 - beta3 ** state['step']
+                grad.mul_(clip_global_grad_norm)
 
                 if self.use_gc:
                     grad = centralize_gradient(grad, gc_conv_only=False)
 
-                grad_diff = grad - prev_grad
-                state['previous_grad'] = grad.clone()
+                grad_diff = grad - state['previous_grad']
+                state['previous_grad'].copy_(grad)
 
                 update = grad + beta2 * grad_diff
 
                 exp_avg.mul_(beta1).add_(grad, alpha=1.0 - beta1)
                 exp_avg_diff.mul_(beta2).add_(grad_diff, alpha=1.0 - beta2)
                 exp_avg_nest.mul_(beta3).addcmul_(update, update, value=1.0 - beta3)
 
-                de_nom = (exp_avg_nest.sqrt_() / math.sqrt(bias_correction3)).add_(self.eps)
+                de_nom = (exp_avg_nest.sqrt_() / bias_correction3_sq).add_(self.eps)
                 perturb = (exp_avg / bias_correction1 + beta2 * exp_avg_diff / bias_correction2).div_(de_nom)
 
                 if group['weight_decouple']:
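
For context, the scale factor that the new get_global_gradient_norm hunk applies to every gradient can be reproduced in isolation. The sketch below is illustrative only: clip_scale, params, and the toy parameters are hypothetical names for this example, not part of the library's API.

import torch


def clip_scale(params, max_grad_norm: float, eps: float = 1e-8):
    """Scale factor applied to every gradient; 1.0 means no clipping."""
    if max_grad_norm == 0.0:
        return 1.0

    # squared global norm = sum of squared per-parameter gradient norms
    global_sq = torch.zeros(1)
    for p in params:
        if p.grad is not None:
            global_sq.add_(torch.linalg.norm(p.grad).pow(2))
    global_norm = torch.sqrt(global_sq)

    # shrink gradients only when the global norm exceeds max_grad_norm
    return torch.clamp(max_grad_norm / (global_norm + eps), max=1.0)


# toy check: per-parameter gradient norms 3 and 4 give a global norm of 5
p1, p2 = torch.nn.Parameter(torch.zeros(3)), torch.nn.Parameter(torch.zeros(4))
p1.grad = torch.full((3,), 3.0 / 3 ** 0.5)  # norm 3
p2.grad = torch.full((4,), 4.0 / 4 ** 0.5)  # norm 4
print(clip_scale([p1, p2], max_grad_norm=1.0))  # ~tensor([0.2000])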
0 commit comments
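
A hedged usage sketch: assuming the class remains importable as pytorch_optimizer.Adan (the import path is an assumption, not confirmed by this diff), the new argument is opt-in, and the default max_grad_norm=0.0 keeps the previous no-clipping behaviour.

import torch
from pytorch_optimizer import Adan  # import path assumed

model = torch.nn.Linear(10, 1)
# max_grad_norm=0.0 (the default) disables gradient clipping entirely
optimizer = Adan(model.parameters(), lr=1e-3, max_grad_norm=1.0)

loss = model(torch.randn(8, 10)).pow(2).mean()
loss.backward()
optimizer.step()
optimizer.zero_grad()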