Commit efebcc5

Add AdaBelief optimizer
- Implement AdaBelief optimizer following the NeurIPS 2020 paper
- Add unit tests following MLX standard patterns
- Update optimizer documentation to include AdaBelief

Fixes #2479
1 parent 895217f commit efebcc5

3 files changed: +122 -0 lines


docs/src/python/optimizers/common_optimizers.rst

Lines changed: 1 addition & 0 deletions
@@ -16,6 +16,7 @@ Common Optimizers
    AdaDelta
    Adam
    AdamW
+   AdaBelief
    Adamax
    Lion
    MultiOptimizer

python/mlx/optimizers/optimizers.py

Lines changed: 87 additions & 0 deletions
@@ -588,6 +588,93 @@ def apply_single(self, gradient: mx.array, parameter: mx.array, state: dict):
         )


+class AdaBelief(Optimizer):
+    r"""The AdaBelief optimizer [1]. AdaBelief adapts step sizes by the "belief"
+    in observed gradients. It uses the variance of the prediction error rather
+    than the gradient itself for the second moment estimate.
+
+    [1]: Zhuang, J., Tang, T., Ding, Y., Tatikonda, S., Dvornek, N.,
+    Papademetris, X. and Duncan, J.S., 2020. AdaBelief optimizer: Adapting
+    stepsizes by the belief in observed gradients. NeurIPS 2020.
+
+    .. math::
+
+        m_{t+1} &= \beta_1 m_t + (1 - \beta_1) g_t \\
+        s_{t+1} &= \beta_2 s_t + (1 - \beta_2) (g_t - m_{t+1})^2 \\
+        w_{t+1} &= w_t - \lambda \frac{m_{t+1}}{\sqrt{s_{t+1}} + \epsilon}
+
+    Args:
+        learning_rate (float or callable): The learning rate :math:`\lambda`.
+        betas (Tuple[float, float], optional): The coefficients
+            :math:`(\beta_1, \beta_2)` used for computing running averages of the
+            gradient and its variance. Default: ``(0.9, 0.999)``
+        eps (float, optional): The term :math:`\epsilon` added to the
+            denominator to improve numerical stability. Default: ``1e-16``
+        weight_decay (float, optional): The weight decay (L2 penalty).
+            Default: ``0.0``
+        bias_correction (bool, optional): If set to ``True``, bias correction
+            is applied. Default: ``True``
+    """
+
+    def __init__(
+        self,
+        learning_rate: Union[float, Callable[[mx.array], mx.array]],
+        betas: List[float] = [0.9, 0.999],
+        eps: float = 1e-16,
+        weight_decay: float = 0.0,
+        bias_correction: bool = True,
+    ):
+        super().__init__()
+
+        self._maybe_schedule("learning_rate", learning_rate)
+        self.betas = betas
+        self.eps = eps
+        self.weight_decay = weight_decay
+        self.bias_correction = bias_correction
+
+    def init_single(self, parameter: mx.array, state: dict):
+        """Initialize optimizer state"""
+        state["m"] = mx.zeros_like(parameter)
+        state["s"] = mx.zeros_like(parameter)
+
+    def apply_single(self, gradient: mx.array, parameter: mx.array, state: dict):
+        """Performs the AdaBelief parameter update and stores :math:`m` and
+        :math:`s` in the optimizer state."""
+        lr = self.learning_rate.astype(gradient.dtype)
+        b1, b2 = self.betas
+        eps = mx.array(self.eps, gradient.dtype)
+        bias_correction = self.bias_correction
+        step = self.step
+
+        m = state["m"]
+        s = state["s"]
+
+        m = b1 * m + (1 - b1) * gradient
+
+        grad_residual = gradient - m
+        s = b2 * s + (1 - b2) * mx.square(grad_residual)
+
+        state["m"] = m
+        state["s"] = s
+
+        if bias_correction:
+            bias_correction_1 = (1 - b1**step).astype(gradient.dtype)
+            bias_correction_2 = (1 - b2**step).astype(gradient.dtype)
+            step_size = lr / bias_correction_1
+            bias_correction_2_sqrt = mx.sqrt(bias_correction_2)
+            denominator = (mx.sqrt(s) / bias_correction_2_sqrt) + eps
+        else:
+            step_size = lr
+            denominator = mx.sqrt(s) + eps
+
+        update = step_size * m / denominator
+
+        if self.weight_decay > 0:
+            parameter = parameter - lr * self.weight_decay * parameter
+
+        return parameter - update
+
+
 class Adamax(Adam):
     r"""The Adamax optimizer, a variant of Adam based on the infinity norm [1].


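For orientation, here is a minimal usage sketch of the optimizer added above. It is not part of the commit; it only uses calls that appear in the diff or in the unit test below (opt.AdaBelief, Optimizer.init, Optimizer.apply_gradients), and the parameter names, shapes, and learning rate are illustrative assumptions.

    import mlx.core as mx
    import mlx.optimizers as opt
    from mlx.utils import tree_map

    # Illustrative parameter tree and all-ones gradients (shapes are arbitrary).
    params = {"w": mx.zeros((10,)), "b": mx.zeros((1,))}
    grads = tree_map(lambda x: mx.ones_like(x), params)

    optim = opt.AdaBelief(learning_rate=1e-2, betas=[0.9, 0.999])
    optim.init(params)  # allocates the per-parameter m and s buffers

    # One AdaBelief step; apply_gradients returns the updated parameter tree.
    params = optim.apply_gradients(grads, params)

Because weight_decay defaults to 0.0, this step reduces to the bias-corrected update from the docstring; with weight_decay > 0, the parameters are additionally shrunk by lr * weight_decay before the update is subtracted.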
python/tests/test_optimizers.py

Lines changed: 34 additions & 0 deletions
@@ -248,6 +248,40 @@ def loss_fn(model, x, y):
             torch_param = param.data.detach().numpy()
             self.assertTrue(np.allclose(torch_param, mlx_param))

+    def test_adabelief(self):
+        params = {
+            "first": [mx.zeros((10,)), mx.zeros((1,))],
+            "second": mx.zeros((1,)),
+        }
+        grads = tree_map(lambda x: mx.ones_like(x), params)
+
+        # Explicit init
+        optim = opt.AdaBelief(learning_rate=1e-2)
+        optim.init(params)
+        self.assertTrue(
+            tree_equal(
+                lambda p, s: mx.array_equal(s["s"], mx.zeros_like(p)),
+                params,
+                optim.state,
+            )
+        )
+        self.assertTrue(
+            tree_equal(
+                lambda p, s: mx.array_equal(s["m"], mx.zeros_like(p)),
+                params,
+                optim.state,
+            )
+        )
+
+        # Implicit init
+        optim = opt.AdaBelief(learning_rate=1e-2, betas=[0.9, 0.999])
+        optim.apply_gradients(grads, params)
+        self.assertTrue(
+            tree_equal(
+                lambda g, s: mx.allclose(s["m"], (1 - 0.9) * g), grads, optim.state
+            )
+        )
+
     def test_lion(self):
         params = {
             "first": [mx.zeros((10,)), mx.zeros((1,))],

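To connect the test to the update rule in the docstring, here is a hedged sketch (not part of the commit) that replays a single AdaBelief step in closed form. It assumes the optimizer's internal step counter equals 1 on the first update, which is what the bias-correction branch in apply_single relies on, and it reuses the implicit-init pattern from the test above.

    import math

    import mlx.core as mx
    import mlx.optimizers as opt

    lr, b1, b2, eps = 1e-2, 0.9, 0.999, 1e-16
    w, g = mx.zeros((3,)), mx.ones((3,))

    # Implicit init, as in the "Implicit init" branch of the test above.
    optim = opt.AdaBelief(learning_rate=lr, betas=[b1, b2], eps=eps)
    new_w = optim.apply_gradients({"w": g}, {"w": w})["w"]

    # Replay the docstring formulas by hand for step t = 1.
    m = (1 - b1) * 1.0                      # first moment of the all-ones gradient
    s = (1 - b2) * (1.0 - m) ** 2           # "belief": variance of the prediction error
    step_size = lr / (1 - b1)               # bias-corrected step size
    denom = math.sqrt(s) / math.sqrt(1 - b2) + eps
    expected = 0.0 - step_size * m / denom  # roughly -0.0111 for these settings

    print(mx.allclose(new_w, mx.full((3,), expected)))  # expected to print True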