)


class AdaBound(Optimizer):
    """
    Reference : https://github.com/Luolc/AdaBound/blob/master/adabound/adabound.py
    Example :
        from pytorch_optimizer import AdaBound
        ...
        model = YourModel()
        optimizer = AdaBound(model.parameters())
        ...
        for input, output in data:
            optimizer.zero_grad()
            loss = loss_function(output, model(input))
            loss.backward()
            optimizer.step()
    """

    def __init__(
        self,
        params: PARAMS,
        lr: float = 1e-3,
        betas: BETAS = (0.9, 0.999),
        final_lr: float = 0.1,
        gamma: float = 1e-3,
        eps: float = 1e-8,
        weight_decay: float = 0.0,
        amsbound: bool = False,
    ):
43+ """AdaBound optimizer
44+ :param params: PARAMS. iterable of parameters to optimize or dicts defining parameter groups
45+ :param lr: float. learning rate
46+ :param final_lr: float. final learning rate
47+ :param betas: BETAS. coefficients used for computing running averages of gradient and the squared hessian trace
48+ :param gamma: float. convergence speed of the bound functions
49+ :param eps: float. term added to the denominator to improve numerical stability
50+ :param weight_decay: float. weight decay (L2 penalty)
51+ :param amsbound: bool. whether to use the AMSBound variant
52+ """
        self.lr = lr
        self.betas = betas
        self.eps = eps
        self.weight_decay = weight_decay
        self.check_valid_parameters()

        defaults: DEFAULT_PARAMETERS = dict(
            lr=lr,
            betas=betas,
            final_lr=final_lr,
            gamma=gamma,
            eps=eps,
            weight_decay=weight_decay,
            amsbound=amsbound,
        )
        super().__init__(params, defaults)

        self.base_lrs = [group['lr'] for group in self.param_groups]

    def check_valid_parameters(self):
        if 0.0 > self.lr:
            raise ValueError(f'Invalid learning rate : {self.lr}')
        if 0.0 > self.eps:
            raise ValueError(f'Invalid eps : {self.eps}')
        if 0.0 > self.weight_decay:
            raise ValueError(f'Invalid weight_decay : {self.weight_decay}')
        if not 0.0 <= self.betas[0] < 1.0:
            raise ValueError(f'Invalid beta_0 : {self.betas[0]}')
        if not 0.0 <= self.betas[1] < 1.0:
            raise ValueError(f'Invalid beta_1 : {self.betas[1]}')

    def __setstate__(self, state: STATE):
        super().__setstate__(state)
        for group in self.param_groups:
            group.setdefault('amsbound', False)

    def step(self, closure: CLOSURE = None) -> LOSS:
        loss: LOSS = None
        if closure is not None:
            loss = closure()

        for group, base_lr in zip(self.param_groups, self.base_lrs):
            for p in group['params']:
                if p.grad is None:
                    continue

                grad = p.grad.data
                if grad.is_sparse:
                    raise RuntimeError(
                        'AdaBound does not support sparse gradients'
                    )

                amsbound = group['amsbound']

                state = self.state[p]

                if len(state) == 0:
                    state['step'] = 0
                    state['exp_avg'] = torch.zeros_like(p)
                    state['exp_avg_sq'] = torch.zeros_like(p)
                    if amsbound:
                        state['max_exp_avg_sq'] = torch.zeros_like(p)

                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
                if amsbound:
                    max_exp_avg_sq = state['max_exp_avg_sq']
                beta1, beta2 = group['betas']

                state['step'] += 1

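                # Classic L2 regularization: fold the weight-decay term into the gradient.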
                if group['weight_decay'] != 0:
                    grad = grad.add(p.data, alpha=group['weight_decay'])

                # Decay the first and second moment running average coefficients
                exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
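                # AMSBound keeps the element-wise maximum of the second moment (as in
                # AMSGrad), so the denominator is non-decreasing over time.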
                if amsbound:
                    torch.max(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq)
                    denom = max_exp_avg_sq.sqrt().add_(group['eps'])
                else:
                    denom = exp_avg_sq.sqrt().add_(group['eps'])

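                # Adam-style bias correction, folded into the scalar step size.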
                bias_correction1 = 1 - beta1 ** state['step']
                bias_correction2 = 1 - beta2 ** state['step']
                step_size = (
                    group['lr']
                    * math.sqrt(bias_correction2)
                    / bias_correction1
                )

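                # Dynamic bounds on the effective learning rate. Both bounds converge to
                # final_lr as the step count grows, moving the update from Adam towards SGD.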
                final_lr = group['final_lr'] * group['lr'] / base_lr
                lower_bound = final_lr * (
                    1 - 1 / (group['gamma'] * state['step'] + 1)
                )
                upper_bound = final_lr * (
                    1 + 1 / (group['gamma'] * state['step'])
                )
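                # Clip the element-wise step size to [lower_bound, upper_bound],
                # then scale by the first moment estimate to form the update.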
                step_size = torch.full_like(denom, step_size)
                step_size.div_(denom).clamp_(lower_bound, upper_bound).mul_(
                    exp_avg
                )

                p.data.add_(-step_size)

        return loss

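# A minimal sketch of the bound schedule shared by AdaBound and AdaBoundW
# (illustration only, not used by the optimizers): with the defaults gamma=1e-3
# and final_lr=0.1, the interval [lower, upper] tightens around final_lr as the
# step count grows, which is what drives the gradual Adam-to-SGD transition.
#
#     gamma, final_lr = 1e-3, 0.1
#     for step in (1, 100, 10_000):
#         lower = final_lr * (1 - 1 / (gamma * step + 1))
#         upper = final_lr * (1 + 1 / (gamma * step))
#         # the bounds move closer to final_lr with every step
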
class AdaBoundW(Optimizer):
    """
    Reference : https://github.com/Luolc/AdaBound
@@ -50,6 +193,11 @@ def __init__(
        :param weight_decay: float. weight decay (L2 penalty)
        :param amsbound: bool. whether to use the AMSBound variant
        """
        self.lr = lr
        self.betas = betas
        self.eps = eps
        self.weight_decay = weight_decay
        self.check_valid_parameters()

        defaults: DEFAULT_PARAMETERS = dict(
            lr=lr,
            betas=betas,
@@ -63,6 +211,18 @@ def __init__(

        self.base_lrs = [group['lr'] for group in self.param_groups]

    def check_valid_parameters(self):
        if 0.0 > self.lr:
            raise ValueError(f'Invalid learning rate : {self.lr}')
        if 0.0 > self.eps:
            raise ValueError(f'Invalid eps : {self.eps}')
        if 0.0 > self.weight_decay:
            raise ValueError(f'Invalid weight_decay : {self.weight_decay}')
        if not 0.0 <= self.betas[0] < 1.0:
            raise ValueError(f'Invalid beta_0 : {self.betas[0]}')
        if not 0.0 <= self.betas[1] < 1.0:
            raise ValueError(f'Invalid beta_1 : {self.betas[1]}')

    def __setstate__(self, state: STATE):
        super().__setstate__(state)
        for group in self.param_groups: