@@ -588,6 +588,93 @@ def apply_single(self, gradient: mx.array, parameter: mx.array, state: dict):
        )


+class AdaBelief(Optimizer):
+    r"""The AdaBelief optimizer [1]. AdaBelief adapts the step size by the
+    "belief" in the observed gradients: the second moment estimate tracks the
+    squared prediction error :math:`(g_t - m_t)^2` rather than the squared
+    gradient.
+
+    [1]: Zhuang, J., Tang, T., Ding, Y., Tatikonda, S., Dvornek, N.,
+    Papademetris, X. and Duncan, J.S., 2020. AdaBelief optimizer: Adapting
+    stepsizes by the belief in observed gradients. NeurIPS 2020.
+
+    .. math::
+
+        m_{t+1} &= \beta_1 m_t + (1 - \beta_1) g_t \\
+        s_{t+1} &= \beta_2 s_t + (1 - \beta_2) (g_t - m_{t+1})^2 \\
+        w_{t+1} &= w_t - \lambda \frac{m_{t+1}}{\sqrt{s_{t+1}} + \epsilon}
+
+    Args:
+        learning_rate (float or callable): The learning rate :math:`\lambda`.
+        betas (Tuple[float, float], optional): The coefficients
+            :math:`(\beta_1, \beta_2)` used for computing running averages of
+            the gradient and its variance. Default: ``(0.9, 0.999)``
+        eps (float, optional): The term :math:`\epsilon` added to the
+            denominator to improve numerical stability. Default: ``1e-16``
+        weight_decay (float, optional): The weight decay. Default: ``0.0``
+        bias_correction (bool, optional): If set to ``True``, bias correction
+            is applied. Default: ``True``
+    """
+
+    def __init__(
+        self,
+        learning_rate: Union[float, Callable[[mx.array], mx.array]],
+        betas: List[float] = [0.9, 0.999],
+        eps: float = 1e-16,
+        weight_decay: float = 0.0,
+        bias_correction: bool = True,
+    ):
+        super().__init__()
+
+        self._maybe_schedule("learning_rate", learning_rate)
+        self.betas = betas
+        self.eps = eps
+        self.weight_decay = weight_decay
+        self.bias_correction = bias_correction
+
+    def init_single(self, parameter: mx.array, state: dict):
+        """Initialize optimizer state"""
+        state["m"] = mx.zeros_like(parameter)
+        state["s"] = mx.zeros_like(parameter)
+
+    def apply_single(self, gradient: mx.array, parameter: mx.array, state: dict):
+        """Performs the AdaBelief parameter update and stores :math:`m` and
+        :math:`s` in the optimizer state."""
+        lr = self.learning_rate.astype(gradient.dtype)
+        b1, b2 = self.betas
+        eps = mx.array(self.eps, gradient.dtype)
+        bias_correction = self.bias_correction
+        step = self.step
+
+        m = state["m"]
+        s = state["s"]
+
+        # First moment: exponential moving average of the gradient.
+        m = b1 * m + (1 - b1) * gradient
+
+        # Second moment: exponential moving average of the squared prediction
+        # error (g_t - m_t)^2, the "belief" term, not of the squared gradient.
+        grad_residual = gradient - m
+        s = b2 * s + (1 - b2) * mx.square(grad_residual)
+
+        state["m"] = m
+        state["s"] = s
+
+        if bias_correction:
+            bias_correction_1 = (1 - b1**step).astype(gradient.dtype)
+            bias_correction_2 = (1 - b2**step).astype(gradient.dtype)
+            step_size = lr / bias_correction_1
+            bias_correction_2_sqrt = mx.sqrt(bias_correction_2)
+            denominator = (mx.sqrt(s) / bias_correction_2_sqrt) + eps
+        else:
+            step_size = lr
+            denominator = mx.sqrt(s) + eps
+
+        update = step_size * m / denominator
+
+        # Decoupled weight decay, applied directly to the parameter.
+        if self.weight_decay > 0:
+            parameter = parameter - lr * self.weight_decay * parameter
+
+        return parameter - update
+
+
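A minimal usage sketch, not part of the diff above: it assumes the class added here is exported as ``mlx.optimizers.AdaBelief`` and drives it with the standard MLX training-step APIs (``nn.value_and_grad`` and ``Optimizer.update``).

import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optim

model = nn.Linear(8, 1)
# Hypothetical: assumes the AdaBelief class from this diff ships in mlx.optimizers.
optimizer = optim.AdaBelief(learning_rate=1e-3, betas=[0.9, 0.999], eps=1e-16)

def loss_fn(model, x, y):
    return nn.losses.mse_loss(model(x), y)

loss_and_grad_fn = nn.value_and_grad(model, loss_fn)

x = mx.random.normal((32, 8))
y = mx.random.normal((32, 1))

loss, grads = loss_and_grad_fn(model, x, y)
optimizer.update(model, grads)  # invokes apply_single on each parameter
mx.eval(model.parameters(), optimizer.state)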
class Adamax(Adam):
    r"""The Adamax optimizer, a variant of Adam based on the infinity norm [1].
