@@ -38,6 +38,8 @@ def __init__(
         gamma: float = 1e-3,
         eps: float = 1e-8,
         weight_decay: float = 0.0,
+        weight_decouple: bool = True,
+        fixed_decay: bool = False,
         amsbound: bool = False,
     ):
         """AdaBound optimizer
@@ -46,6 +48,8 @@ def __init__(
         :param final_lr: float. final learning rate
         :param betas: BETAS. coefficients used for computing running averages of gradient and the squared hessian trace
         :param gamma: float. convergence speed of the bound functions
+        :param weight_decouple: bool. the optimizer uses decoupled weight decay as in AdamW
+        :param fixed_decay: bool. fix weight decay (apply it without scaling by the learning rate)
         :param eps: float. term added to the denominator to improve numerical stability
         :param weight_decay: float. weight decay (L2 penalty)
         :param amsbound: bool. whether to use the AMSBound variant
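
Taken together: with weight_decouple=True the decay is applied directly to the parameters before the Adam-style update (p is scaled by 1 - lr * weight_decay, or by 1 - weight_decay when fixed_decay=True), while weight_decouple=False keeps the classic L2 behaviour of adding weight_decay * p to the gradient, as the step() hunk further down shows.
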
@@ -54,6 +58,8 @@ def __init__(
         self.betas = betas
         self.eps = eps
         self.weight_decay = weight_decay
+        self.weight_decouple = weight_decouple
+        self.fixed_decay = fixed_decay

         defaults: DEFAULT_PARAMETERS = dict(
             lr=lr,
@@ -119,150 +125,14 @@ def step(self, closure: CLOSURE = None) -> LOSS:

                 state['step'] += 1

-                if group['weight_decay'] != 0:
-                    grad = grad.add(group['weight_decay'], p.data)
-
-                # Decay the first and second moment running average coefficient
-                exp_avg.mul_(beta1).add_(1 - beta1, grad)
-                exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad)
-                if amsbound:
-                    torch.max(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq)
-                    denom = max_exp_avg_sq.sqrt().add_(group['eps'])
+                if self.weight_decouple:
+                    if not self.fixed_decay:
+                        p.data.mul_(1.0 - group['lr'] * group['weight_decay'])
+                    else:
+                        p.data.mul_(1.0 - group['weight_decay'])
                 else:
-                    denom = exp_avg_sq.sqrt().add_(group['eps'])
-
-                bias_correction1 = 1 - beta1 ** state['step']
-                bias_correction2 = 1 - beta2 ** state['step']
-                step_size = (
-                    group['lr']
-                    * math.sqrt(bias_correction2)
-                    / bias_correction1
-                )
-
-                final_lr = group['final_lr'] * group['lr'] / base_lr
-                lower_bound = final_lr * (
-                    1 - 1 / (group['gamma'] * state['step'] + 1)
-                )
-                upper_bound = final_lr * (
-                    1 + 1 / (group['gamma'] * state['step'])
-                )
-                step_size = torch.full_like(denom, step_size)
-                step_size.div_(denom).clamp_(lower_bound, upper_bound).mul_(
-                    exp_avg
-                )
-
-                p.data.add_(-step_size)
-
-        return loss
-
-
-class AdaBoundW(Optimizer):
-    """
-    Reference : https://github.com/Luolc/AdaBound
-    Example :
-        from pytorch_optimizer import AdaBoundW
-        ...
-        model = YourModel()
-        optimizer = AdaBoundW(model.parameters())
-        ...
-        for input, output in data:
-            optimizer.zero_grad()
-            loss = loss_function(output, model(input))
-            loss.backward()
-            optimizer.step()
-    """
-
-    def __init__(
-        self,
-        params: PARAMS,
-        lr: float = 1e-3,
-        betas: BETAS = (0.9, 0.999),
-        final_lr: float = 0.1,
-        gamma: float = 1e-3,
-        eps: float = 1e-8,
-        weight_decay: float = 0.0,
-        amsbound: bool = False,
-    ):
-        """AdaBound optimizer with decoupled weight decay
-        :param params: PARAMS. iterable of parameters to optimize or dicts defining parameter groups
-        :param lr: float. learning rate
-        :param final_lr: float. final learning rate
-        :param betas: BETAS. coefficients used for computing running averages of gradient and the squared hessian trace
-        :param gamma: float. convergence speed of the bound functions
-        :param eps: float. term added to the denominator to improve numerical stability
-        :param weight_decay: float. weight decay (L2 penalty)
-        :param amsbound: bool. whether to use the AMSBound variant
-        """
-        self.lr = lr
-        self.betas = betas
-        self.eps = eps
-        self.weight_decay = weight_decay
-
-        defaults: DEFAULT_PARAMETERS = dict(
-            lr=lr,
-            betas=betas,
-            final_lr=final_lr,
-            gamma=gamma,
-            eps=eps,
-            weight_decay=weight_decay,
-            amsbound=amsbound,
-        )
-        super().__init__(params, defaults)
-
-        self.base_lrs = [group['lr'] for group in self.param_groups]
-
-    def check_valid_parameters(self):
-        if 0.0 > self.lr:
-            raise ValueError(f'Invalid learning rate : {self.lr}')
-        if 0.0 > self.eps:
-            raise ValueError(f'Invalid eps : {self.eps}')
-        if 0.0 > self.weight_decay:
-            raise ValueError(f'Invalid weight_decay : {self.weight_decay}')
-        if not 0.0 <= self.betas[0] < 1.0:
-            raise ValueError(f'Invalid beta_0 : {self.betas[0]}')
-        if not 0.0 <= self.betas[1] < 1.0:
-            raise ValueError(f'Invalid beta_1 : {self.betas[1]}')
-
-    def __setstate__(self, state: STATE):
-        super().__setstate__(state)
-        for group in self.param_groups:
-            group.setdefault('amsbound', False)
-
-    def step(self, closure: CLOSURE = None) -> LOSS:
-        loss: LOSS = None
-        if closure is not None:
-            loss = closure()
-
-        for group, base_lr in zip(self.param_groups, self.base_lrs):
-            for p in group['params']:
-                if p.grad is None:
-                    continue
-
-                p.mul_(1 - base_lr * group['weight_decay'])
-
-                grad = p.grad.data
-                if grad.is_sparse:
-                    raise RuntimeError(
-                        'AdaBound does not support sparse gradients'
-                    )
-
-                amsbound = group['amsbound']
-
-                state = self.state[p]
-
-                if len(state) == 0:
-                    state['step'] = 0
-                    state['exp_avg'] = torch.zeros_like(p)
-                    state['exp_avg_sq'] = torch.zeros_like(p)
-                    if amsbound:
-                        state['max_exp_avg_sq'] = torch.zeros_like(p)
-
-                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
-                if amsbound:
-                    max_exp_avg_sq = state['max_exp_avg_sq']
-                beta1, beta2 = group['betas']
-
-                state['step'] += 1
+                    if group['weight_decay'] != 0:
+                        grad.add_(p.data, alpha=group['weight_decay'])

                 # Decay the first and second moment running average coefficient
                 exp_avg.mul_(beta1).add_(1 - beta1, grad)
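
For reference, a usage sketch of the updated constructor, in the style of the Example block removed above. It assumes AdaBound is exported from pytorch_optimizer the same way AdaBoundW was; model, data, and loss_function are placeholders, and the two commented flags are the ones added in this commit.

    import torch

    from pytorch_optimizer import AdaBound

    model = torch.nn.Linear(10, 1)  # placeholder model
    optimizer = AdaBound(
        model.parameters(),
        lr=1e-3,
        final_lr=0.1,
        weight_decay=1e-2,
        weight_decouple=True,  # AdamW-style decoupled decay (new flag)
        fixed_decay=False,     # scale the decay by the learning rate each step (new flag)
    )

    for input, target in data:
        optimizer.zero_grad()
        loss = loss_function(model(input), target)
        loss.backward()
        optimizer.step()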