 from torch.nn import functional as f
 from torch.optim import Optimizer
 
-from pytorch_optimizer.base.exception import NegativeLRError, NoSparseGradientError, ZeroParameterSizeError
+from pytorch_optimizer.base.exception import NoSparseGradientError, ZeroParameterSizeError
 from pytorch_optimizer.base.optimizer import BaseOptimizer
 from pytorch_optimizer.base.types import BETAS, CLOSURE, DEFAULTS, LOSS, PARAMETERS
 from pytorch_optimizer.optimizer.agc import agc
@@ -73,6 +73,7 @@ def __init__( # pylint: disable=R0913
         eps: float = 1e-8,
     ):
         self.lr = lr
+        self.min_lr = warm_down_min_lr
         self.beta0 = beta0
         self.betas = betas
         self.use_softplus = use_softplus
@@ -96,7 +97,6 @@ def __init__( # pylint: disable=R0913
         # learning rate
         self.starting_lr = lr
         self.current_lr = lr
-        self.min_lr = warm_down_min_lr
 
         defaults: DEFAULTS = {
             'lr': lr,
@@ -123,6 +123,7 @@ def __init__( # pylint: disable=R0913
 
     def validate_parameters(self):
         self.validate_learning_rate(self.lr)
+        self.validate_learning_rate(self.min_lr)
         self.validate_betas(self.betas)
         self.validate_beta0(self.beta0)
         self.validate_weight_decay(self.weight_decay)
@@ -169,7 +170,7 @@ def warm_up_dampening(self, lr: float, step: int) -> float:
 
         return new_lr
 
-    def get_warm_down(self, lr: float, iteration: int) -> float:
+    def warm_down(self, lr: float, iteration: int) -> float:
         if iteration < self.start_warm_down:
             return lr
 
@@ -182,6 +183,7 @@ def get_warm_down(self, lr: float, iteration: int) -> float:
 
         new_lr: float = self.starting_lr - self.warm_down_lr_delta * warm_down_pct
         new_lr = max(new_lr, self.min_lr)
+
         self.current_lr = new_lr
 
         return new_lr
@@ -198,6 +200,11 @@ def step(self, closure: CLOSURE = None) -> LOSS:
 
         # Phase 1 - Accumulate all the variance_ma_sum to use in stable weight decay
         for group in self.param_groups:
+            if 'step' in group:
+                group['step'] += 1
+            else:
+                group['step'] = 1
+
             beta1, beta2 = group['betas']
             for p in group['params']:
                 if p.grad is None:
@@ -214,7 +221,6 @@ def step(self, closure: CLOSURE = None) -> LOSS:
 
                 state = self.state[p]
                 if len(state) == 0:
-                    state['step'] = 0
                     state['grad_ma'] = torch.zeros_like(p)
                     state['variance_ma'] = torch.zeros_like(p)
                     state['lookahead_params'] = torch.empty_like(p)
@@ -227,17 +233,14 @@ def step(self, closure: CLOSURE = None) -> LOSS:
                     grad = centralize_gradient(grad, gc_conv_only=False)
                 grad = normalize_gradient(grad)
 
-                state['step'] += 1
-
-                bias_correction2 = 1.0 - beta2 ** state['step']
+                bias_correction2 = 1.0 - beta2 ** group['step']
 
                 # second moment estimation
                 # using positive-negative momentum and bias correction
                 variance_ma = state['variance_ma']
                 variance_ma.mul_(beta2).addcmul_(grad, grad, value=1.0 - beta2)
                 variance_ma_sum += (variance_ma / bias_correction2).sum()
 
-        # stable weight decay
         if param_size == 0:
             raise ZeroParameterSizeError()
 
@@ -247,42 +250,32 @@ def step(self, closure: CLOSURE = None) -> LOSS:
 
         # Phase 2 - Apply weight decay and step
         for group in self.param_groups:
-            if len(self.state) == 0:
-                continue
-            p = next(iter(self.state.keys()))
-            lr = group["lr"]
-            step = self.state[p]["step"]
-
+            lr = group['lr']
             beta1, beta2 = group['betas']
-            bias_correction1 = 1.0 - beta1 ** step  # fmt: skip
-            bias_correction2_sq = math.sqrt(1.0 - beta2 ** step)  # fmt: skip
 
-            noise_norm: float = math.sqrt((1.0 + beta2) ** 2 + beta2 ** 2)  # fmt: skip
-
-            # warm up
-            lr = self.warm_up_dampening(lr, step)
-
-            # warm down
-            lr = self.get_warm_down(lr, step)
-            if lr < 0.0:
-                raise NegativeLRError(lr)
+            bias_correction1 = 1.0 - beta1 ** group['step']  # fmt: skip
+            bias_correction2_sq = math.sqrt(1.0 - beta2 ** group['step'])  # fmt: skip
 
-            # stable decay
-            decay = group['weight_decay']
-            if decay:
-                p.mul_(1.0 - decay * lr / variance_normalized)
+            noise_norm: float = math.sqrt((1.0 + beta2) ** 2 + beta2 ** 2)  # fmt: skip
 
-            # norm loss
-            u_norm = unit_norm(p)
-            correction = 2.0 * self.norm_loss_factor * (1.0 - torch.div(1, u_norm + self.eps))
-            p.mul_(1.0 - lr * correction)
+            # warm up & down
+            lr = self.warm_up_dampening(lr, group['step'])
+            lr = self.warm_down(lr, group['step'])
 
             for p in group['params']:
                 if p.grad is None:
                     continue
 
+                # stable weight decay
+                if group['weight_decay']:
+                    p.mul_(1.0 - group['weight_decay'] * lr / variance_normalized)
+
+                # norm loss
+                correction = 2.0 * self.norm_loss_factor * (1.0 - torch.div(1, unit_norm(p) + self.eps))
+                p.mul_(1.0 - lr * correction)
+
                 state = self.state[p]
-                if state['step'] % 2 == 1:
+                if group['step'] % 2 == 1:
                     grad_ma, neg_grad_ma = state['grad_ma'], state['neg_grad_ma']
                 else:
                     grad_ma, neg_grad_ma = state['neg_grad_ma'], state['grad_ma']
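Note on the refactor above: the per-parameter state['step'] counter becomes a single group['step'] counter that is incremented once at the start of Phase 1, so the bias corrections and the warm-up / warm-down schedule in Phase 2 are computed once per parameter group instead of being read from whichever parameter happens to appear first in self.state; the stable weight decay and norm-loss terms also move inside the per-parameter loop so they apply to every parameter rather than only one. The fragment below is a minimal sketch of the per-group step pattern only, assuming a plain Adam-style update; the class name GroupStepSketch and its hyper-parameters are illustrative and are not part of pytorch-optimizer.

import math

import torch
from torch.optim import Optimizer


class GroupStepSketch(Optimizer):
    def __init__(self, params, lr: float = 1e-3, betas=(0.9, 0.999), eps: float = 1e-8):
        super().__init__(params, {'lr': lr, 'betas': betas, 'eps': eps})

    @torch.no_grad()
    def step(self, closure=None):
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            # one counter per parameter group, shared by every tensor in it
            # (same effect as the "if 'step' in group" branch in the diff)
            group['step'] = group.get('step', 0) + 1

            beta1, beta2 = group['betas']
            bias_correction1 = 1.0 - beta1 ** group['step']
            bias_correction2_sq = math.sqrt(1.0 - beta2 ** group['step'])

            for p in group['params']:
                if p.grad is None:
                    continue

                state = self.state[p]
                if len(state) == 0:
                    # per-parameter state no longer carries its own 'step'
                    state['exp_avg'] = torch.zeros_like(p)
                    state['exp_avg_sq'] = torch.zeros_like(p)

                state['exp_avg'].mul_(beta1).add_(p.grad, alpha=1.0 - beta1)
                state['exp_avg_sq'].mul_(beta2).addcmul_(p.grad, p.grad, value=1.0 - beta2)

                denom = state['exp_avg_sq'].sqrt().div_(bias_correction2_sq).add_(group['eps'])
                p.addcdiv_(state['exp_avg'], denom, value=-group['lr'] / bias_correction1)

        return loss


# usage: behaves like any torch optimizer
# opt = GroupStepSketch(model.parameters(), lr=1e-3)
# loss.backward(); opt.step(); opt.zero_grad()

A side effect of keeping the counter in the group dict is that it is saved and restored by the optimizer's standard state_dict / load_state_dict handling of param groups, so resuming training keeps the schedule position without any extra bookkeeping.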