@@ -3,10 +3,11 @@
 import torch
 from torch.optim.optimizer import Optimizer
 
+from pytorch_optimizer.base_optimizer import BaseOptimizer
 from pytorch_optimizer.types import BETAS, CLOSURE, DEFAULTS, LOSS, PARAMETERS, STATE
 
 
-class AdaBelief(Optimizer):
+class AdaBelief(Optimizer, BaseOptimizer):
     """
     Reference : https://github.com/juntang-zhuang/Adabelief-Optimizer
     Example :
@@ -37,7 +38,7 @@ def __init__(
         adamd_debias_term: bool = False,
         eps: float = 1e-16,
     ):
-        """AdaBelief
+        """AdaBelief optimizer
         :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups
         :param lr: float. learning rate
         :param betas: BETAS. coefficients used for computing running averages of gradient and the squared hessian trace
@@ -62,7 +63,7 @@ def __init__(
         self.adamd_debias_term = adamd_debias_term
         self.eps = eps
 
-        self.check_valid_parameters()
+        self.validate_parameters()
 
         defaults: DEFAULTS = dict(
             lr=lr,
@@ -75,17 +76,11 @@ def __init__(
         )
         super().__init__(params, defaults)
 
-    def check_valid_parameters(self):
-        if self.lr < 0.0:
-            raise ValueError(f'Invalid learning rate : {self.lr}')
-        if not 0.0 <= self.betas[0] < 1.0:
-            raise ValueError(f'Invalid beta_0 : {self.betas[0]}')
-        if not 0.0 <= self.betas[1] < 1.0:
-            raise ValueError(f'Invalid beta_1 : {self.betas[1]}')
-        if self.weight_decay < 0.0:
-            raise ValueError(f'Invalid weight_decay : {self.weight_decay}')
-        if self.eps < 0.0:
-            raise ValueError(f'Invalid eps : {self.eps}')
+    def validate_parameters(self):
+        self.validate_learning_rate(self.lr)
+        self.validate_betas(self.betas)
+        self.validate_weight_decay(self.weight_decay)
+        self.validate_epsilon(self.eps)
 
     def __setstate__(self, state: STATE):
         super().__setstate__(state)
@@ -125,7 +120,7 @@ def step(self, closure: CLOSURE = None) -> LOSS:
                     grad = grad.float()
 
                 p_fp32 = p
-                if p.dtype in {torch.float16, torch.bfloat16}:
+                if p.dtype in (torch.float16, torch.bfloat16):
                     p_fp32 = p_fp32.float()
 
                 state = self.state[p]
@@ -158,14 +153,7 @@ def step(self, closure: CLOSURE = None) -> LOSS:
                 exp_avg_var.mul_(beta2).addcmul_(grad_residual, grad_residual, value=1.0 - beta2)
 
                 if group['amsgrad']:
-                    max_exp_avg_var = state['max_exp_avg_var']
-
-                    torch.max(
-                        max_exp_avg_var,
-                        exp_avg_var.add_(group['eps']),
-                        out=max_exp_avg_var,
-                    )
-
+                    max_exp_avg_var = torch.max(state['max_exp_avg_var'], exp_avg_var.add_(group['eps']))
                     de_nom = (max_exp_avg_var.sqrt() / math.sqrt(bias_correction2)).add_(group['eps'])
                 else:
                     de_nom = (exp_avg_var.add_(group['eps']).sqrt() / math.sqrt(bias_correction2)).add_(group['eps'])
@@ -176,7 +164,7 @@ def step(self, closure: CLOSURE = None) -> LOSS:
                         step_size /= bias_correction1
                     p_fp32.addcdiv_(exp_avg, de_nom, value=-step_size)
                 else:
-                    buffered = group['buffer'][int(state['step'] % 10)]
+                    buffered = group['buffer'][state['step'] % 10]
                     if state['step'] == buffered[0]:
                         n_sma, step_size = buffered[1], buffered[2]
                     else:
@@ -213,7 +201,7 @@ def step(self, closure: CLOSURE = None) -> LOSS:
                     elif step_size > 0:
                         p_fp32.add_(exp_avg, alpha=-step_size * group['lr'])
 
-                if p.dtype in {torch.float16, torch.bfloat16}:
+                if p.dtype in (torch.float16, torch.bfloat16):
                     p.copy_(p_fp32)
 
         return loss
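
The new validate_* calls rely on helpers inherited from BaseOptimizer, whose implementations are not part of this diff. As a rough guide, the sketch below simply mirrors the inline checks removed above; the method names come from the diff, but the bodies, signatures, and error messages are assumptions rather than the library's actual code.

# Hedged sketch: the real helpers live in pytorch_optimizer/base_optimizer.py and
# are not shown in this diff; these bodies mirror the removed inline checks.
class BaseOptimizer:
    @staticmethod
    def validate_learning_rate(learning_rate: float):
        if learning_rate < 0.0:
            raise ValueError(f'Invalid learning rate : {learning_rate}')

    @staticmethod
    def validate_betas(betas):
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError(f'Invalid beta_0 : {betas[0]}')
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError(f'Invalid beta_1 : {betas[1]}')

    @staticmethod
    def validate_weight_decay(weight_decay: float):
        if weight_decay < 0.0:
            raise ValueError(f'Invalid weight_decay : {weight_decay}')

    @staticmethod
    def validate_epsilon(epsilon: float):
        if epsilon < 0.0:
            raise ValueError(f'Invalid eps : {epsilon}')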
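
For context, a minimal usage sketch of the optimizer this file defines. The toy model, data, and hyperparameter values are illustrative, and the top-level import of AdaBelief from pytorch_optimizer is assumed from the package layout rather than shown in this diff.

# Illustrative usage only; the import path, model, and data are assumptions.
import torch

from pytorch_optimizer import AdaBelief

model = torch.nn.Linear(10, 1)
optimizer = AdaBelief(model.parameters(), lr=1e-3, betas=(0.9, 0.999), eps=1e-16)

x, y = torch.randn(32, 10), torch.randn(32, 1)
loss = torch.nn.functional.mse_loss(model(x), y)

optimizer.zero_grad()
loss.backward()
optimizer.step()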