@@ -96,12 +96,9 @@ def __init__( # pylint: disable=R0913
         self.lookahead_blending_alpha = lookahead_blending_alpha
         self.norm_loss_factor = norm_loss_factor
 
-        # lookahead
         self.lookahead_step: int = 0
-
-        # learning rate
-        self.starting_lr = lr
-        self.current_lr = lr
+        self.starting_lr: float = lr
+        self.current_lr: float = lr
 
         defaults: DEFAULTS = {
             'lr': lr,
@@ -114,7 +111,6 @@ def __init__( # pylint: disable=R0913
         }
         super().__init__(params, defaults)
 
-        # warmup iterations
         self.num_warm_up_iterations: int = (
             self.build_warm_up_iterations(num_iterations, betas[1])
             if num_warm_up_iterations is None
@@ -140,8 +136,7 @@ def reset(self):
 
                 state['grad_ma'] = torch.zeros_like(p)
                 state['variance_ma'] = torch.zeros_like(p)
-                state['lookahead_params'] = torch.empty_like(p)
-                state['lookahead_params'].copy_(p)
+                state['lookahead_params'] = p.clone()
                 state['neg_grad_ma'] = torch.zeros_like(p)
                 state['max_variance_ma'] = torch.zeros_like(p)
 
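
Both reset() and the lazy state initialization in step() (next hunks) now seed the lookahead buffer with p.clone() instead of allocating uninitialized storage and copying into it. A minimal sketch of the equivalence, using an illustrative tensor rather than an optimizer parameter:

    import torch

    p = torch.randn(3, 3)

    # old pattern: allocate uninitialized storage, then copy the parameter into it
    buf_old = torch.empty_like(p)
    buf_old.copy_(p)

    # new pattern: clone allocates and copies in one call,
    # preserving dtype, device and layout
    buf_new = p.clone()

    assert torch.equal(buf_old, buf_new)
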
@@ -162,28 +157,21 @@ def warm_up_dampening(self, lr: float, step: int) -> float:
 
         warm_up_current_pct: float = min(1.0, (step / self.num_warm_up_iterations))
 
-        new_lr: float = lr * warm_up_current_pct
-        self.current_lr = new_lr
+        self.current_lr = lr * warm_up_current_pct
 
-        return new_lr
+        return self.current_lr
 
     def warm_down(self, lr: float, iteration: int) -> float:
         if iteration < self.start_warm_down:
             return lr
 
         # start iteration from 1, not 0
-        warm_down_iteration: int = (iteration + 1) - self.start_warm_down
-        warm_down_iteration = max(warm_down_iteration, 1)
-
-        warm_down_pct: float = warm_down_iteration / (self.num_warm_down_iterations + 1)
-        warm_down_pct = min(warm_down_pct, 1.0)
-
-        new_lr: float = self.starting_lr - self.warm_down_lr_delta * warm_down_pct
-        new_lr = max(new_lr, self.min_lr)
+        warm_down_iteration: int = max((iteration + 1) - self.start_warm_down, 1)
+        warm_down_pct: float = min(warm_down_iteration / (self.num_warm_down_iterations + 1), 1.0)
 
-        self.current_lr = new_lr
+        self.current_lr = max(self.starting_lr - self.warm_down_lr_delta * warm_down_pct, self.min_lr)
 
-        return new_lr
+        return self.current_lr
 
     @torch.no_grad()
     def step(self, closure: CLOSURE = None) -> LOSS:
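
The warm-up and warm-down refactor folds each intermediate variable into a single clamped expression and records the result in self.current_lr before returning it. Below is a standalone sketch of the resulting schedule; the free functions are illustrative only, and it assumes warm_down_lr_delta equals starting_lr - min_lr, which is not shown in this diff:

    def warm_up(lr: float, step: int, num_warm_up_iterations: int) -> float:
        # linear ramp from 0 to lr over the warm-up window, flat afterwards
        return lr * min(1.0, step / num_warm_up_iterations)


    def warm_down(lr: float, iteration: int, starting_lr: float, min_lr: float,
                  start_warm_down: int, num_warm_down_iterations: int) -> float:
        if iteration < start_warm_down:
            return lr  # warm-down has not started yet: pass the incoming lr through

        # count from 1, not 0, and clamp the decay fraction to [0, 1]
        warm_down_iteration = max((iteration + 1) - start_warm_down, 1)
        warm_down_pct = min(warm_down_iteration / (num_warm_down_iterations + 1), 1.0)

        # linear decay toward min_lr (assumes warm_down_lr_delta == starting_lr - min_lr)
        return max(starting_lr - (starting_lr - min_lr) * warm_down_pct, min_lr)

In the Phase 2 loop further down, the schedule is now fed group['lr'] directly instead of a locally cached lr, so the base rate always comes from the param group.
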
@@ -220,8 +208,7 @@ def step(self, closure: CLOSURE = None) -> LOSS:
                 if len(state) == 0:
                     state['grad_ma'] = torch.zeros_like(p)
                     state['variance_ma'] = torch.zeros_like(p)
-                    state['lookahead_params'] = torch.empty_like(p)
-                    state['lookahead_params'].copy_(p)
+                    state['lookahead_params'] = p.clone()
                     state['neg_grad_ma'] = torch.zeros_like(p)
                     state['max_variance_ma'] = torch.zeros_like(p)
 
@@ -245,7 +232,6 @@ def step(self, closure: CLOSURE = None) -> LOSS:
 
         # Phase 2 - Apply weight decay and step
         for group in self.param_groups:
-            lr: float = group['lr']
             beta1, beta2 = group['betas']
 
             bias_correction1: float = 1.0 - beta1 ** group['step']  # fmt: skip
@@ -254,7 +240,7 @@ def step(self, closure: CLOSURE = None) -> LOSS:
             noise_norm: float = math.sqrt((1.0 + beta2) ** 2 + beta2 ** 2)  # fmt: skip
 
             # warm up & down
-            lr = self.warm_up_dampening(lr, group['step'])
+            lr: float = self.warm_up_dampening(group['lr'], group['step'])
             lr = self.warm_down(lr, group['step'])
 
             for p in group['params']:
@@ -287,16 +273,16 @@ def step(self, closure: CLOSURE = None) -> LOSS:
 
                 de_nom = (variance_ma.sqrt() / bias_correction2_sq).add_(group['eps'])
 
+                if self.use_softplus:
+                    de_nom = f.softplus(de_nom, beta=self.beta_softplus)
+
                 grad = p.grad
                 centralize_gradient(grad, gc_conv_only=False)
                 normalize_gradient(grad)
 
                 grad_ma.mul_(beta1 ** 2).add_(grad, alpha=1.0 - beta1 ** 2)  # fmt: skip
 
-                step_size: float = lr if group['adam_debias'] else lr / bias_correction1
-
-                if self.use_softplus:
-                    de_nom = f.softplus(de_nom, beta=self.beta_softplus)
+                step_size: float = self.apply_adam_debias(group['adam_debias'], lr, bias_correction1)
 
                 pn_momentum = grad_ma.mul(1.0 + 1.0).add(neg_grad_ma, alpha=-1.0).mul(1.0 / noise_norm)
                 p.addcdiv_(pn_momentum, de_nom, value=-step_size)
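
Two things change in this last hunk: the optional softplus smoothing is applied right where de_nom is built, and the inline debias branch becomes a call to apply_adam_debias. The helper's body is not part of this diff; the sketch below presumes it reproduces the removed inline expression:

    import torch
    import torch.nn.functional as f


    def apply_adam_debias(adam_debias: bool, step_size: float, bias_correction1: float) -> float:
        # presumed behaviour, mirroring the removed inline branch:
        # skip the bias correction when adam_debias is enabled
        return step_size if adam_debias else step_size / bias_correction1


    # softplus keeps the denominator strictly positive and smooths small values,
    # acting as a soft floor on top of the eps already added to de_nom
    de_nom = torch.tensor([1e-8, 1e-3, 1.0])
    de_nom = f.softplus(de_nom, beta=50.0)  # the beta_softplus value here is illustrative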