diff --git a/docs/src/python/optimizers/common_optimizers.rst b/docs/src/python/optimizers/common_optimizers.rst
index 4975df541b..e75e4de10e 100644
--- a/docs/src/python/optimizers/common_optimizers.rst
+++ b/docs/src/python/optimizers/common_optimizers.rst
@@ -16,6 +16,7 @@ Common Optimizers
    AdaDelta
    Adam
    AdamW
+   AdaBelief
    Adamax
    Lion
    MultiOptimizer
diff --git a/python/mlx/optimizers/optimizers.py b/python/mlx/optimizers/optimizers.py
index 2cc9e26b17..c7eddc2c7b 100644
--- a/python/mlx/optimizers/optimizers.py
+++ b/python/mlx/optimizers/optimizers.py
@@ -588,6 +588,93 @@ def apply_single(self, gradient: mx.array, parameter: mx.array, state: dict):
         )


+class AdaBelief(Optimizer):
+    r"""The AdaBelief optimizer [1]. AdaBelief adapts step sizes by the
+    "belief" in observed gradients. It uses the variance of the prediction
+    error rather than the gradient itself for the second moment estimate.
+
+    [1]: Zhuang, J., Tang, T., Ding, Y., Tatikonda, S., Dvornek, N.,
+    Papademetris, X. and Duncan, J.S., 2020. AdaBelief optimizer: Adapting
+    stepsizes by the belief in observed gradients. NeurIPS 2020.
+
+    .. math::
+
+        m_{t+1} &= \beta_1 m_t + (1 - \beta_1) g_t \\
+        s_{t+1} &= \beta_2 s_t + (1 - \beta_2) (g_t - m_{t+1})^2 \\
+        w_{t+1} &= w_t - \lambda \frac{m_{t+1}}{\sqrt{s_{t+1}} + \epsilon}
+
+    Args:
+        learning_rate (float or callable): The learning rate :math:`\lambda`.
+        betas (Tuple[float, float], optional): The coefficients
+          :math:`(\beta_1, \beta_2)` used for computing running averages of the
+          gradient and its variance. Default: ``(0.9, 0.999)``
+        eps (float, optional): The term :math:`\epsilon` added to the
+          denominator to improve numerical stability. Default: ``1e-16``
+        weight_decay (float, optional): The weight decay (L2 penalty).
+          Default: ``0.0``
+        bias_correction (bool, optional): If set to ``True``, bias correction
+          is applied. Default: ``True``
+    """
+
+    def __init__(
+        self,
+        learning_rate: Union[float, Callable[[mx.array], mx.array]],
+        betas: List[float] = [0.9, 0.999],
+        eps: float = 1e-16,
+        weight_decay: float = 0.0,
+        bias_correction: bool = True,
+    ):
+        super().__init__()
+
+        self._maybe_schedule("learning_rate", learning_rate)
+        self.betas = betas
+        self.eps = eps
+        self.weight_decay = weight_decay
+        self.bias_correction = bias_correction
+
+    def init_single(self, parameter: mx.array, state: dict):
+        """Initialize optimizer state"""
+        state["m"] = mx.zeros_like(parameter)
+        state["s"] = mx.zeros_like(parameter)
+
+    def apply_single(self, gradient: mx.array, parameter: mx.array, state: dict):
+        """Performs the AdaBelief parameter update and stores :math:`m` and
+        :math:`s` in the optimizer state."""
+        lr = self.learning_rate.astype(gradient.dtype)
+        b1, b2 = self.betas
+        eps = mx.array(self.eps, gradient.dtype)
+        bias_correction = self.bias_correction
+        step = self.step
+
+        m = state["m"]
+        s = state["s"]
+
+        m = b1 * m + (1 - b1) * gradient
+
+        grad_residual = gradient - m
+        s = b2 * s + (1 - b2) * mx.square(grad_residual)
+
+        state["m"] = m
+        state["s"] = s
+
+        if bias_correction:
+            bias_correction_1 = (1 - b1**step).astype(gradient.dtype)
+            bias_correction_2 = (1 - b2**step).astype(gradient.dtype)
+            step_size = lr / bias_correction_1
+            bias_correction_2_sqrt = mx.sqrt(bias_correction_2)
+            denominator = (mx.sqrt(s) / bias_correction_2_sqrt) + eps
+        else:
+            step_size = lr
+            denominator = mx.sqrt(s) + eps
+
+        update = step_size * m / denominator
+
+        if self.weight_decay > 0:
+            parameter = parameter - lr * self.weight_decay * parameter
+
+        return parameter - update
+
+
 class Adamax(Adam):
     r"""The Adamax optimizer, a variant of Adam based on the infinity norm [1].
diff --git a/python/tests/test_optimizers.py b/python/tests/test_optimizers.py
index 6869ac357a..4fe1bc88b4 100644
--- a/python/tests/test_optimizers.py
+++ b/python/tests/test_optimizers.py
@@ -248,6 +248,40 @@ def loss_fn(model, x, y):
             torch_param = param.data.detach().numpy()
             self.assertTrue(np.allclose(torch_param, mlx_param))

+    def test_adabelief(self):
+        params = {
+            "first": [mx.zeros((10,)), mx.zeros((1,))],
+            "second": mx.zeros((1,)),
+        }
+        grads = tree_map(lambda x: mx.ones_like(x), params)
+
+        # Explicit init
+        optim = opt.AdaBelief(learning_rate=1e-2)
+        optim.init(params)
+        self.assertTrue(
+            tree_equal(
+                lambda p, s: mx.array_equal(s["s"], mx.zeros_like(p)),
+                params,
+                optim.state,
+            )
+        )
+        self.assertTrue(
+            tree_equal(
+                lambda p, s: mx.array_equal(s["m"], mx.zeros_like(p)),
+                params,
+                optim.state,
+            )
+        )
+
+        # Implicit init
+        optim = opt.AdaBelief(learning_rate=1e-2, betas=[0.9, 0.999])
+        optim.apply_gradients(grads, params)
+        self.assertTrue(
+            tree_equal(
+                lambda g, s: mx.allclose(s["m"], (1 - 0.9) * g), grads, optim.state
+            )
+        )
+
     def test_lion(self):
         params = {
             "first": [mx.zeros((10,)), mx.zeros((1,))],
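
To show how the new optimizer plugs into a standard MLX training loop, here is a minimal sketch. It assumes the diff above has been applied so that `opt.AdaBelief` is available; the synthetic linear-regression data, `loss_fn`, the parameter dict, and the chosen learning rate and step count are illustrative inventions and not part of the patch, while `mx.grad`, `mx.eval`, and `Optimizer.apply_gradients` are existing MLX APIs.

    import mlx.core as mx
    import mlx.optimizers as opt

    # Illustrative synthetic linear-regression problem (not part of the patch).
    x = mx.random.normal((256, 4))
    w_true = mx.array([1.0, -2.0, 0.5, 3.0])
    y = x @ w_true

    params = {"w": mx.zeros((4,))}


    def loss_fn(params):
        # Mean squared error of the linear model x @ w against y.
        return mx.mean(mx.square(x @ params["w"] - y))


    grad_fn = mx.grad(loss_fn)

    # Defaults mirror the new class: betas=[0.9, 0.999], eps=1e-16,
    # weight_decay=0.0, bias_correction=True.
    optimizer = opt.AdaBelief(learning_rate=1e-2)

    for _ in range(500):
        grads = grad_fn(params)
        # apply_gradients returns the updated parameter tree; the optimizer
        # keeps m, s, and the step counter in optimizer.state.
        params = optimizer.apply_gradients(grads, params)
        mx.eval(params, optimizer.state)

    print(loss_fn(params))  # the loss should be much smaller than at w = 0

Note that after the first `apply_gradients` call the state holds `m = (1 - 0.9) * g` for each parameter, which is exactly what the implicit-init branch of `test_adabelief` asserts.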