1 change: 1 addition & 0 deletions docs/src/python/optimizers/common_optimizers.rst
@@ -16,6 +16,7 @@ Common Optimizers
AdaDelta
Adam
AdamW
AdaBelief
Adamax
Lion
MultiOptimizer
87 changes: 87 additions & 0 deletions python/mlx/optimizers/optimizers.py
@@ -588,6 +588,93 @@ def apply_single(self, gradient: mx.array, parameter: mx.array, state: dict):
)


class AdaBelief(Optimizer):
r"""The AdaBelief optimizer [1]. AdaBelief adapts step sizes by the "belief"
in observed gradients. It uses the variance of the prediction error rather
than the gradient itself for the second moment estimate.

[1]: Zhuang, J., Tang, T., Ding, Y., Tatikonda, S., Dvornek, N.,
Papademetris, X. and Duncan, J.S., 2020. AdaBelief optimizer: Adapting
stepsizes by the belief in observed gradients. NeurIPS 2020.

.. math::

m_{t+1} &= \beta_1 m_t + (1 - \beta_1) g_t \\
s_{t+1} &= \beta_2 s_t + (1 - \beta_2) (g_t - m_{t+1})^2 \\
w_{t+1} &= w_t - \lambda \frac{m_{t+1}}{\sqrt{s_{t+1}} + \epsilon}

Args:
learning_rate (float or callable): The learning rate :math:`\lambda`.
betas (Tuple[float, float], optional): The coefficients
:math:`(\beta_1, \beta_2)` used for computing running averages of the
gradient and its variance. Default: ``(0.9, 0.999)``
eps (float, optional): The term :math:`\epsilon` added to the
denominator to improve numerical stability. Default: ``1e-16``
weight_decay (float, optional): The weight decay, applied in a
decoupled (AdamW-style) fashion. Default: ``0.0``
bias_correction (bool, optional): If ``True``, apply Adam-style bias
correction to the first and second moment estimates. Default: ``True``
"""

def __init__(
self,
learning_rate: Union[float, Callable[[mx.array], mx.array]],
betas: List[float] = [0.9, 0.999],
eps: float = 1e-16,
weight_decay: float = 0.0,
bias_correction: bool = True,
):
super().__init__()

self._maybe_schedule("learning_rate", learning_rate)
self.betas = betas
self.eps = eps
self.weight_decay = weight_decay
self.bias_correction = bias_correction

def init_single(self, parameter: mx.array, state: dict):
"""Initialize optimizer state"""
state["m"] = mx.zeros_like(parameter)
state["s"] = mx.zeros_like(parameter)

def apply_single(self, gradient: mx.array, parameter: mx.array, state: dict):
"""Performs the AdaBelief parameter update and stores :math:`m` and
:math:`s` in the optimizer state."""
lr = self.learning_rate.astype(gradient.dtype)
b1, b2 = self.betas
eps = mx.array(self.eps, gradient.dtype)
bias_correction = self.bias_correction
step = self.step

m = state["m"]
s = state["s"]

m = b1 * m + (1 - b1) * gradient

grad_residual = gradient - m
s = b2 * s + (1 - b2) * mx.square(grad_residual)

state["m"] = m
state["s"] = s

if bias_correction:
# Adam-style bias correction: rescale m and sqrt(s) to undo the bias
# toward zero from the zero-initialized state in early steps.
bias_correction_1 = (1 - b1**step).astype(gradient.dtype)
bias_correction_2 = (1 - b2**step).astype(gradient.dtype)
step_size = lr / bias_correction_1
bias_correction_2_sqrt = mx.sqrt(bias_correction_2)
denominator = (mx.sqrt(s) / bias_correction_2_sqrt) + eps
else:
step_size = lr
denominator = mx.sqrt(s) + eps

update = step_size * m / denominator

# Decoupled weight decay (AdamW-style): shrink the parameter with the
# base learning rate, independently of the adaptive update above.
if self.weight_decay > 0:
parameter = parameter - lr * self.weight_decay * parameter

return parameter - update


class Adamax(Adam):
r"""The Adamax optimizer, a variant of Adam based on the infinity norm [1].

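A minimal usage sketch, assuming the new class is exposed as `mlx.optimizers.AdaBelief` with the import spelling the tests below use; the toy `params`/`grads` names are illustrative, not from this diff:

import mlx.core as mx
import mlx.optimizers as opt

# A toy parameter tree and matching gradients.
params = {"w": mx.zeros((3,))}
grads = {"w": mx.ones_like(params["w"])}

# Defaults follow the class above: betas=(0.9, 0.999), eps=1e-16.
optimizer = opt.AdaBelief(learning_rate=1e-2)
params = optimizer.apply_gradients(grads, params)  # one AdaBelief step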
34 changes: 34 additions & 0 deletions python/tests/test_optimizers.py
@@ -248,6 +248,40 @@ def loss_fn(model, x, y):
torch_param = param.data.detach().numpy()
self.assertTrue(np.allclose(torch_param, mlx_param))

def test_adabelief(self):
params = {
"first": [mx.zeros((10,)), mx.zeros((1,))],
"second": mx.zeros((1,)),
}
grads = tree_map(lambda x: mx.ones_like(x), params)

# Explicit init
optim = opt.AdaBelief(learning_rate=1e-2)
optim.init(params)
self.assertTrue(
tree_equal(
lambda p, s: mx.array_equal(s["s"], mx.zeros_like(p)),
params,
optim.state,
)
)
self.assertTrue(
tree_equal(
lambda p, s: mx.array_equal(s["m"], mx.zeros_like(p)),
params,
optim.state,
)
)

# Implicit init
optim = opt.AdaBelief(learning_rate=1e-2, betas=[0.9, 0.999])
optim.apply_gradients(grads, params)
self.assertTrue(
tree_equal(
lambda g, s: mx.allclose(s["m"], (1 - 0.9) * g), grads, optim.state
)
)

def test_lion(self):
params = {
"first": [mx.zeros((10,)), mx.zeros((1,))],
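The first update can also be checked by hand against the docstring's update rule. A sketch, assuming numpy plus the same mlx imports as in the tests: with g = 1, zero initial state, and the default betas, m_1 = 0.1, s_1 = 0.001 * 0.81, and the bias-corrected step is 0.01 / 0.9 ≈ 0.0111.

import numpy as np
import mlx.core as mx
import mlx.optimizers as opt

lr, b1, b2, eps = 1e-2, 0.9, 0.999, 1e-16
g = 1.0
m = (1 - b1) * g                          # 0.1
s = (1 - b2) * (g - m) ** 2               # 0.001 * 0.81
step = (lr / (1 - b1)) * m / (np.sqrt(s) / np.sqrt(1 - b2) + eps)

optim = opt.AdaBelief(learning_rate=lr)
new_params = optim.apply_gradients({"w": mx.ones((1,))}, {"w": mx.zeros((1,))})
assert np.allclose(np.array(new_params["w"]), -step, atol=1e-6)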