@@ -161,3 +161,137 @@ def step(self, closure: CLOSURE = None) -> LOSS:
161161 p .add_ (update , alpha = - step_size )
162162
163163 return loss
164+
165+
class SimplifiedAdEMAMix(BaseOptimizer):
    r"""Connections between Schedule-Free Optimizers, AdEMAMix, and Accelerated SGD Variants.

    :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.
    :param lr: float. learning rate.
    :param betas: BETAS. coefficients used for computing running averages of gradient and the squared hessian trace.
    :param weight_decay: float. weight decay (L2 penalty).
    :param weight_decouple: bool. the optimizer uses decoupled weight decay as in AdamW.
    :param fixed_decay: bool. fix weight decay.
    :param alpha: float. coefficient for mixing the current gradient and EMA.
    :param beta1_warmup: Optional[int]. number of warmup steps used to increase beta1.
    :param min_beta1: float. minimum value of beta1 to start from.
    :param eps: float. term added to the denominator to improve numerical stability.
    """

    def __init__(
        self,
        params: PARAMETERS,
        lr: float = 1e-4,
        betas: BETAS = (0.99, 0.95),
        weight_decay: float = 0.0,
        weight_decouple: bool = True,
        fixed_decay: bool = False,
        alpha: float = 0.0,
        beta1_warmup: Optional[int] = None,
        min_beta1: float = 0.9,
        eps: float = 1e-8,
        **kwargs,
    ):
        self.validate_learning_rate(lr)
        self.validate_betas(betas)
        self.validate_non_negative(alpha, 'alpha')
        self.validate_non_negative(min_beta1, 'min_beta1')
        self.validate_non_negative(weight_decay, 'weight_decay')
        self.validate_non_negative(eps, 'eps')

        defaults: DEFAULTS = {
            'lr': lr,
            'betas': betas,
            'alpha': alpha,
            'beta1_warmup': beta1_warmup,
            'min_beta1': min_beta1,
            'weight_decay': weight_decay,
            'weight_decouple': weight_decouple,
            'fixed_decay': fixed_decay,
            'eps': eps,
        }

        super().__init__(params, defaults)

    def __str__(self) -> str:
        return 'SimAdEMAMix'

    @torch.no_grad()
    def reset(self):
        pass

    @staticmethod
    def linear_hl_warmup_scheduler(step: int, beta_end: float, beta_start: float = 0.0, warmup: int = 1) -> float:
        r"""Warm beta up from beta_start to beta_end over the first `warmup` steps.

        The interpolation is linear in half-life space (not in beta itself), so the
        effective averaging horizon of the EMA grows smoothly during warmup.
        """

        def f(beta: float, eps: float = 1e-8) -> float:
            # half-life (in steps) of an EMA with coefficient beta; eps guards log(0) when beta == 0.
            return math.log(0.5) / math.log(beta + eps) - 1.0

        def f_inv(t: float) -> float:
            # inverse of f: the beta whose EMA half-life is t steps.
            return math.pow(0.5, 1.0 / (t + 1))

        if step < warmup:
            a: float = step / float(warmup)
            return f_inv((1.0 - a) * f(beta_start) + a * f(beta_end))

        return beta_end

    @torch.no_grad()
    def step(self, closure: CLOSURE = None) -> LOSS:
        r"""Perform a single optimization step.

        :param closure: CLOSURE. a closure that re-evaluates the model and returns the loss.
        """
        loss: LOSS = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            # per-group step counter, created lazily on the first call.
            group['step'] = group.get('step', 0) + 1

            beta1, beta2 = group['betas']

            if group['beta1_warmup']:
                beta1 = self.linear_hl_warmup_scheduler(
                    group['step'], beta_end=beta1, beta_start=group['min_beta1'], warmup=group['beta1_warmup']
                )

            for p in group['params']:
                if p.grad is None:
                    continue

                grad = p.grad
                if grad.is_sparse:
                    raise NoSparseGradientError(str(self))

                state = self.state[p]

                if len(state) == 0:
                    state['exp_avg'] = torch.zeros_like(p)
                    state['exp_avg_sq'] = torch.zeros_like(p)
                    # running normalizers for the (possibly warmup-varying) betas.
                    state['num_sum'] = 0.0
                    state['den_sum'] = 0.0

                self.apply_weight_decay(
                    p=p,
                    grad=grad,
                    lr=group['lr'],
                    weight_decay=group['weight_decay'],
                    weight_decouple=group['weight_decouple'],
                    fixed_decay=group['fixed_decay'],
                )

                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']

                exp_avg.mul_(beta1).add_(grad, alpha=1.0 - beta1)
                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1.0 - beta2)

                state['num_sum'] = beta1 * state['num_sum'] + 1.0
                state['den_sum'] = beta2 * state['den_sum'] + (1.0 - beta2)

                de_nom = exp_avg_sq.sqrt().add_(math.sqrt(state['den_sum']) * group['eps'])

                # With the default alpha == 0.0 the numerator is exactly exp_avg, so skip
                # the per-parameter multiply-add allocation in that (common) case.
                if group['alpha'] == 0.0:
                    numerator = exp_avg
                else:
                    numerator = group['alpha'] * grad + exp_avg

                # .div (out-of-place) first so exp_avg is never mutated, then in-place ops.
                update = numerator.div(de_nom).div_(math.sqrt(state['den_sum']))

                p.add_(update, alpha=-group['lr'])

        return loss
0 commit comments