 import torch
 from torch.optim.optimizer import Optimizer

-from pytorch_optimizer.types import (
-    BETAS,
-    CLOSURE,
-    DEFAULT_PARAMETERS,
-    LOSS,
-    PARAMS,
-    STATE,
-)
+from pytorch_optimizer.types import BETAS, CLOSURE, DEFAULTS, LOSS, PARAMETERS, STATE


 class AdaBelief(Optimizer):
@@ -31,60 +24,47 @@ class AdaBelief(Optimizer):

     def __init__(
         self,
-        params: PARAMS,
+        params: PARAMETERS,
         lr: float = 1e-3,
         betas: BETAS = (0.9, 0.999),
-        eps: float = 1e-16,
         weight_decay: float = 0.0,
         n_sma_threshold: int = 5,
-        amsgrad: bool = False,
         weight_decouple: bool = True,
         fixed_decay: bool = False,
         rectify: bool = True,
         degenerated_to_sgd: bool = True,
+        amsgrad: bool = False,
+        eps: float = 1e-16,
     ):
-        """AdaBelief optimizer
-        :param params: PARAMS. iterable of parameters to optimize
-            or dicts defining parameter groups
+        """
+        :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups
         :param lr: float. learning rate
-        :param betas: BETAS. coefficients used for computing running averages
-            of gradient and the squared hessian trace
-        :param eps: float. term added to the denominator
-            to improve numerical stability
+        :param betas: BETAS. coefficients used for computing running averages of gradient and the squared hessian trace
         :param weight_decay: float. weight decay (L2 penalty)
         :param n_sma_threshold: (recommended is 5)
-        :param amsgrad: bool. whether to use the AMSBound variant
-        :param weight_decouple: bool. the optimizer uses decoupled weight decay
-            as in AdamW
+        :param weight_decouple: bool. the optimizer uses decoupled weight decay as in AdamW
         :param fixed_decay: bool.
         :param rectify: bool. perform the rectified update similar to RAdam
-        :param degenerated_to_sgd: bool. perform SGD update
-            when variance of gradient is high
+        :param degenerated_to_sgd: bool. perform SGD update when variance of gradient is high
+        :param amsgrad: bool. whether to use the AMSBound variant
+        :param eps: float. term added to the denominator to improve numerical stability
         """
         self.lr = lr
         self.betas = betas
-        self.eps = eps
         self.weight_decay = weight_decay
         self.n_sma_threshold = n_sma_threshold
-        self.degenerated_to_sgd = degenerated_to_sgd
         self.weight_decouple = weight_decouple
-        self.rectify = rectify
         self.fixed_decay = fixed_decay
+        self.rectify = rectify
         self.degenerated_to_sgd = degenerated_to_sgd
+        self.eps = eps

-        if (
-            isinstance(params, (list, tuple))
-            and len(params) > 0
-            and isinstance(params[0], dict)
-        ):
+        if isinstance(params, (list, tuple)) and len(params) > 0 and isinstance(params[0], dict):
             for param in params:
-                if 'betas' in param and (
-                    param['betas'][0] != betas[0]
-                    or param['betas'][1] != betas[1]
-                ):
+                if 'betas' in param and (param['betas'][0] != betas[0] or param['betas'][1] != betas[1]):
                     param['buffer'] = [[None, None, None] for _ in range(10)]

-        defaults: DEFAULT_PARAMETERS = dict(
+        defaults: DEFAULTS = dict(
             lr=lr,
             betas=betas,
             eps=eps,
@@ -129,9 +109,7 @@ def step(self, closure: CLOSURE = None) -> LOSS:

                 grad = p.grad.data
                 if grad.is_sparse:
-                    raise RuntimeError(
-                        'AdaBelief does not support sparse gradients'
-                    )
+                    raise RuntimeError('AdaBelief does not support sparse gradients')

                 amsgrad = group['amsgrad']

@@ -163,9 +141,7 @@ def step(self, closure: CLOSURE = None) -> LOSS:

                 exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
                 grad_residual = grad - exp_avg
-                exp_avg_var.mul_(beta2).addcmul_(
-                    grad_residual, grad_residual, value=1 - beta2
-                )
+                exp_avg_var.mul_(beta2).addcmul_(grad_residual, grad_residual, value=1 - beta2)

                 if amsgrad:
                     max_exp_avg_var = state['max_exp_avg_var']
@@ -176,14 +152,9 @@ def step(self, closure: CLOSURE = None) -> LOSS:
                         out=max_exp_avg_var,
                     )

-                    denom = (
-                        max_exp_avg_var.sqrt() / math.sqrt(bias_correction2)
-                    ).add_(group['eps'])
+                    denom = (max_exp_avg_var.sqrt() / math.sqrt(bias_correction2)).add_(group['eps'])
                 else:
-                    denom = (
-                        exp_avg_var.add_(group['eps']).sqrt()
-                        / math.sqrt(bias_correction2)
-                    ).add_(group['eps'])
+                    denom = (exp_avg_var.add_(group['eps']).sqrt() / math.sqrt(bias_correction2)).add_(group['eps'])

                 if not self.rectify:
                     step_size = group['lr'] / bias_correction1
@@ -196,9 +167,7 @@ def step(self, closure: CLOSURE = None) -> LOSS:
                         buffered[0] = state['step']
                         beta2_t = beta2 ** state['step']
                         n_sma_max = 2 / (1 - beta2) - 1
-                        n_sma = n_sma_max - 2 * state['step'] * beta2_t / (
-                            1 - beta2_t
-                        )
+                        n_sma = n_sma_max - 2 * state['step'] * beta2_t / (1 - beta2_t)
                         buffered[1] = n_sma

                         if n_sma >= self.n_sma_threshold:
@@ -219,9 +188,7 @@ def step(self, closure: CLOSURE = None) -> LOSS:

                     if n_sma >= self.n_sma_threshold:
                         denom = exp_avg_var.sqrt().add_(group['eps'])
-                        p.data.addcdiv_(
-                            exp_avg, denom, value=-step_size * group['lr']
-                        )
+                        p.data.addcdiv_(exp_avg, denom, value=-step_size * group['lr'])
                     elif step_size > 0:
                         p.data.add_(exp_avg, alpha=-step_size * group['lr'])

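For reference, a minimal usage sketch of the constructor after this reordering. It is not part of the commit: the linear model and dummy batch are placeholders, and it assumes `AdaBelief` is re-exported from the package root of `pytorch_optimizer`.

```python
# Hypothetical usage sketch, not from this commit: exercises the new keyword
# order (amsgrad and eps moved to the end of __init__ above).
import torch

from pytorch_optimizer import AdaBelief  # assumes the package-root re-export

model = torch.nn.Linear(10, 1)  # placeholder model
optimizer = AdaBelief(
    model.parameters(),
    lr=1e-3,
    betas=(0.9, 0.999),
    weight_decay=0.0,
    n_sma_threshold=5,
    weight_decouple=True,
    fixed_decay=False,
    rectify=True,
    degenerated_to_sgd=True,
    amsgrad=False,
    eps=1e-16,
)

x, y = torch.randn(8, 10), torch.randn(8, 1)  # dummy batch
loss = torch.nn.functional.mse_loss(model(x), y)
loss.backward()
optimizer.step()
optimizer.zero_grad()
```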