Skip to content

Commit 00400de

Browse files
committed
feature: implement SimplifiedAdEMAMix optimizer
1 parent b672031 commit 00400de

File tree

1 file changed

+134
-0
lines changed

1 file changed

+134
-0
lines changed

pytorch_optimizer/optimizer/ademamix.py

Lines changed: 134 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -161,3 +161,137 @@ def step(self, closure: CLOSURE = None) -> LOSS:
161161
p.add_(update, alpha=-step_size)
162162

163163
return loss
164+
165+
166+
class SimplifiedAdEMAMix(BaseOptimizer):
    r"""Connections between Schedule-Free Optimizers, AdEMAMix, and Accelerated SGD Variants.

    :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.
    :param lr: float. learning rate.
    :param betas: BETAS. coefficients used for computing running averages of gradient and the squared hessian trace.
    :param alpha: float. coefficient for mixing the current gradient and EMA.
    :param beta1_warmup: Optional[int]. number of warmup steps used to increase beta1.
    :param min_beta1: float. minimum value of beta1 to start from.
    :param weight_decay: float. weight decay (L2 penalty).
    :param weight_decouple: bool. the optimizer uses decoupled weight decay as in AdamW.
    :param fixed_decay: bool. fix weight decay.
    :param eps: float. term added to the denominator to improve numerical stability.
    """

    def __init__(
        self,
        params: PARAMETERS,
        lr: float = 1e-4,
        betas: BETAS = (0.99, 0.95),
        weight_decay: float = 0.0,
        weight_decouple: bool = True,
        fixed_decay: bool = False,
        alpha: float = 0.0,
        beta1_warmup: Optional[int] = None,
        min_beta1: float = 0.9,
        eps: float = 1e-8,
        **kwargs,
    ):
        # Validate hyper-parameters up front so a bad config fails fast.
        self.validate_learning_rate(lr)
        self.validate_betas(betas)
        self.validate_non_negative(alpha, 'alpha')
        self.validate_non_negative(min_beta1, 'min_beta1')
        self.validate_non_negative(weight_decay, 'weight_decay')
        self.validate_non_negative(eps, 'eps')

        defaults: DEFAULTS = {
            'lr': lr,
            'betas': betas,
            'alpha': alpha,
            'beta1_warmup': beta1_warmup,
            'min_beta1': min_beta1,
            'weight_decay': weight_decay,
            'weight_decouple': weight_decouple,
            'fixed_decay': fixed_decay,
            'eps': eps,
        }

        super().__init__(params, defaults)

    def __str__(self) -> str:
        return 'SimAdEMAMix'

    @torch.no_grad()
    def reset(self):
        pass

    @staticmethod
    def linear_hl_warmup_scheduler(step: int, beta_end: float, beta_start: float = 0.0, warmup: int = 1) -> float:
        """Anneal beta from ``beta_start`` to ``beta_end`` linearly in EMA half-life space.

        Interpolating the half-life (rather than beta directly) keeps the effective
        averaging horizon growing smoothly during warmup.
        """

        def half_life(beta: float, eps: float = 1e-8) -> float:
            # number of steps after which an EMA with coefficient `beta` decays to 1/2
            return math.log(0.5) / math.log(beta + eps) - 1.0

        def beta_from_half_life(hl: float) -> float:
            # inverse of `half_life`
            return math.pow(0.5, 1.0 / (hl + 1))

        if step >= warmup:
            return beta_end

        frac: float = step / float(warmup)
        return beta_from_half_life((1.0 - frac) * half_life(beta_start) + frac * half_life(beta_end))

    @torch.no_grad()
    def step(self, closure: CLOSURE = None) -> LOSS:
        loss: LOSS = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            # per-group step counter, created lazily on the first call
            group['step'] = group.get('step', 0) + 1

            beta1, beta2 = group['betas']

            # optionally ramp beta1 up from `min_beta1` during the warmup phase
            if group['beta1_warmup']:
                beta1 = self.linear_hl_warmup_scheduler(
                    group['step'], beta_end=beta1, beta_start=group['min_beta1'], warmup=group['beta1_warmup']
                )

            for p in group['params']:
                if p.grad is None:
                    continue

                grad = p.grad
                if grad.is_sparse:
                    raise NoSparseGradientError(str(self))

                state = self.state[p]

                if len(state) == 0:
                    state['exp_avg'] = torch.zeros_like(p)
                    state['exp_avg_sq'] = torch.zeros_like(p)
                    state['num_sum'] = 0.0
                    state['den_sum'] = 0.0

                # (decoupled) weight decay is applied before the moment updates
                self.apply_weight_decay(
                    p=p,
                    grad=grad,
                    lr=group['lr'],
                    weight_decay=group['weight_decay'],
                    weight_decouple=group['weight_decouple'],
                    fixed_decay=group['fixed_decay'],
                )

                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']

                # first and second moment EMAs, updated in place
                exp_avg.mul_(beta1).add_(grad, alpha=1.0 - beta1)
                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1.0 - beta2)

                # NOTE(review): `num_sum` is tracked but not used in the update below —
                # confirm against the reference implementation whether it should be.
                state['num_sum'] = beta1 * state['num_sum'] + 1.0
                state['den_sum'] = beta2 * state['den_sum'] + (1.0 - beta2)

                # sqrt of the second-moment bias correction, hoisted out of the two divisions
                bias_sq: float = math.sqrt(state['den_sum'])

                denom = exp_avg_sq.sqrt().add_(bias_sq * group['eps'])

                # mix the raw gradient with the momentum EMA, then normalize
                update = (group['alpha'] * grad + exp_avg).div_(denom).div_(bias_sq)

                p.add_(update, alpha=-group['lr'])

        return loss

0 commit comments

Comments
 (0)