|
| 1 | +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. |
| 2 | +# SPDX-License-Identifier: Apache-2.0 |
| 3 | +# |
| 4 | +# Licensed under the Apache License, Version 2.0 (the "License"); |
| 5 | +# you may not use this file except in compliance with the License. |
| 6 | +# You may obtain a copy of the License at |
| 7 | +# |
| 8 | +# http://www.apache.org/licenses/LICENSE-2.0 |
| 9 | +# |
| 10 | +# Unless required by applicable law or agreed to in writing, software |
| 11 | +# distributed under the License is distributed on an "AS IS" BASIS, |
| 12 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 13 | +# See the License for the specific language governing permissions and |
| 14 | +# limitations under the License. |
| 15 | +import math |
| 16 | +from typing import Optional, Tuple |
| 17 | + |
| 18 | +import torch |
| 19 | + |
| 20 | + |
| 21 | +__all__ = [ |
| 22 | + "calculate_sim_ademamix_update", |
| 23 | + "calculate_ademamix_update", |
| 24 | +] |
| 25 | + |
| 26 | + |
| 27 | +@torch.compile # type: ignore[misc] |
| 28 | +@torch.no_grad() # type: ignore[misc] |
| 29 | +def calculate_sim_ademamix_update( |
| 30 | + grad: torch.Tensor, |
| 31 | + exp_avg: torch.Tensor, |
| 32 | + exp_avg_sq: torch.Tensor, |
| 33 | + num_beta_fast_warmup_steps: Optional[int], |
| 34 | + min_beta_fast: float, |
| 35 | + betas: Tuple[float, float], |
| 36 | + step: int, |
| 37 | + eps: float, |
| 38 | + correct_bias: bool, |
| 39 | + alpha: float = 2, |
| 40 | +) -> torch.Tensor: |
| 41 | + """Performs simplified AdEMAMix update. |
| 42 | +
|
| 43 | + This function performs the computation of 1 step of simplified AdEMAMix. |
| 44 | + Based on https://github.com/DepenM/Simplified-AdEMAMix/blob/main/simplified_AdEMAMix.py |
| 45 | + and https://arxiv.org/abs/2409.03137. |
| 46 | +
|
| 47 | + The update rule is as follows: |
| 48 | +
|
| 49 | + .. math:: |
| 50 | + m_t = \\beta_{\\text{fast}} m_{t-1} + g_t \\\\ |
| 51 | + v_t = \\beta_2 v_{t-1} + (1 - \\beta_2) g_t^2 \\\\ |
| 52 | + \\hat{m}_t = \\frac{m_t}{(1 - \\beta_{\\text{fast}}^t) / (1 - \\beta_{\\text{fast}})} \\\\ |
| 53 | + \\hat{v}_t = \\frac{v_t}{1 - \\beta_2^t} \\\\ |
| 54 | + \\text{update} = \\frac{\\alpha g_t + \\hat{m}_t}{\\sqrt{\\hat{v}_t} + \\epsilon} |
| 55 | +
|
| 56 | + Args: |
| 57 | + grad: The gradient tensor. |
| 58 | + exp_avg: The accumulated first moment of the gradient. |
| 59 | + exp_avg_sq: The accumulated second moment of the gradient. |
| 60 | + num_beta_fast_warmup_steps: Number of warmup steps used to increase beta_fast |
| 61 | + min_beta_fast: The minimum beta_fast value used at initialization |
| 62 | + betas: The EMA beta coefficients for the Adam update. |
| 63 | + step: The current step of the optimizer, used to compute the bias correction terms. |
| 64 | + eps: The epsilon for the Adam second moment update. |
| 65 | + correct_bias: Whether to correct the bias of the AdEMAMix update. |
| 66 | + alpha: Coeficient for mixing the current gradient and EMA. |
| 67 | +
|
| 68 | + Returns: |
| 69 | + The simplified-AdEMAMix update. |
| 70 | + """ |
| 71 | + beta_fast_final, beta2 = betas |
| 72 | + |
| 73 | + # Compute beta_fast based on scheduler |
| 74 | + if num_beta_fast_warmup_steps is not None: |
| 75 | + beta_fast = _linear_half_life_warmup_scheduler( |
| 76 | + step, beta_end=beta_fast_final, beta_start=min_beta_fast, num_warmup_steps=num_beta_fast_warmup_steps |
| 77 | + ) |
| 78 | + else: |
| 79 | + beta_fast = beta_fast_final |
| 80 | + |
| 81 | + # Decay the first moment "theory style": https://arxiv.org/abs/2502.02431 |
| 82 | + exp_avg.mul_(beta_fast).add_(grad, alpha=1.0) |
| 83 | + |
| 84 | + # Decay the second moment exponential moving average |
| 85 | + exp_avg_sq.lerp_(grad.square(), 1 - beta2) |
| 86 | + |
| 87 | + if correct_bias: |
| 88 | + # theory style bias correction |
| 89 | + bias_correction1 = (1 - beta_fast**step) / (1 - beta_fast) |
| 90 | + bias_correction2 = 1 - beta2**step |
| 91 | + else: |
| 92 | + bias_correction1 = 1 |
| 93 | + bias_correction2 = 1 |
| 94 | + |
| 95 | + # step size correction for optimizer states EMA |
| 96 | + momentum = exp_avg / bias_correction1 |
| 97 | + adam_second_moment = exp_avg_sq / bias_correction2 |
| 98 | + adam_second_moment = adam_second_moment.sqrt() + eps |
| 99 | + |
| 100 | + return (alpha * grad + momentum) / adam_second_moment |
| 101 | + |
| 102 | + |
| 103 | +@torch.compile # type: ignore[misc] |
| 104 | +@torch.no_grad() # type: ignore[misc] |
| 105 | +def calculate_ademamix_update( |
| 106 | + grad: torch.Tensor, |
| 107 | + exp_avg_fast: torch.Tensor, |
| 108 | + exp_avg_slow: torch.Tensor, |
| 109 | + exp_avg_sq: torch.Tensor, |
| 110 | + num_beta_slow_warmup_steps: Optional[int], |
| 111 | + num_alpha_warmup_steps: Optional[int], |
| 112 | + betas: Tuple[float, float, float], |
| 113 | + step: int, |
| 114 | + eps: float, |
| 115 | + correct_bias: bool, |
| 116 | + alpha: float = 2, |
| 117 | +) -> torch.Tensor: |
| 118 | + """Performs AdEMAMix update. |
| 119 | +
|
| 120 | + This function performs the computation of 1 step of AdEMAMix. |
| 121 | + Based on https://github.com/apple/ml-ademamix/blob/main/pytorch/ademamix.py |
| 122 | + and https://arxiv.org/abs/2409.03137. |
| 123 | +
|
| 124 | + The update rule is as follows: |
| 125 | +
|
| 126 | + .. math:: |
| 127 | + m_t^{\\text{fast}} = \\beta_{\\text{fast}} m_{t-1}^{\\text{fast}} + (1 - \\beta_{\\text{fast}}) g_t \\\\ |
| 128 | + m_t^{\\text{slow}} = \\beta_{\\text{slow}} m_{t-1}^{\\text{slow}} + (1 - \\beta_{\\text{slow}}) g_t \\\\ |
| 129 | + v_t = \\beta_2 v_{t-1} + (1 - \\beta_2) g_t^2 \\\\ |
| 130 | + \\hat{m}_t^{\\text{fast}} = \\frac{m_t^{\\text{fast}}}{1 - \\beta_{\\text{fast}}^t} \\\\ |
| 131 | + \\hat{v}_t = \\frac{v_t}{1 - \\beta_2^t} \\\\ |
| 132 | + \\text{update} = \\frac{\\hat{m}_t^{\\text{fast}} + \\alpha m_t^{\\text{slow}}}{\\sqrt{\\hat{v}_t} + \\epsilon} |
| 133 | +
|
| 134 | + Args: |
| 135 | + grad: The gradient tensor. |
| 136 | + exp_avg_fast: The accumulated first moment of the gradient with fast time constant. |
| 137 | + exp_avg_slow: The accumulated first moment of the gradient with slow time constant. |
| 138 | + exp_avg_sq: The accumulated second moment of the gradient. |
| 139 | + num_beta_slow_warmup_steps: Number of warmup steps used to increase beta_slow |
| 140 | + num_alpha_warmup_steps: Number of warmup steps used to increase alpha |
| 141 | + betas: The EMA beta coefficients for the Adam update. |
| 142 | + step: The current step of the optimizer, used to compute the bias correction terms. |
| 143 | + eps: The epsilon for the Adam second moment update. |
| 144 | + correct_bias: Whether to correct the bias of the AdEMAMix update. |
| 145 | + alpha: Coeficient for mixing the current gradient and EMA, the final value to use in case of scheduling. |
| 146 | +
|
| 147 | + Returns: |
| 148 | + The AdEMAMix update. |
| 149 | + """ |
| 150 | + beta_fast, beta2, beta_slow_final = betas |
| 151 | + |
| 152 | + if num_alpha_warmup_steps is not None: |
| 153 | + alpha = _linear_warmup_scheduler(step, alpha_end=alpha, alpha_start=0, num_warmup_steps=num_alpha_warmup_steps) |
| 154 | + else: |
| 155 | + alpha = alpha |
| 156 | + |
| 157 | + # Compute beta_slow based on scheduler with half-life linear warmup |
| 158 | + # beta_start is usually set to beta_fast |
| 159 | + if num_beta_slow_warmup_steps is not None: |
| 160 | + beta_slow = _linear_half_life_warmup_scheduler( |
| 161 | + step, beta_end=beta_slow_final, beta_start=beta_fast, num_warmup_steps=num_beta_slow_warmup_steps |
| 162 | + ) |
| 163 | + else: |
| 164 | + beta_slow = beta_slow_final |
| 165 | + |
| 166 | + if correct_bias: |
| 167 | + bias_correction1 = 1 - beta_fast**step |
| 168 | + bias_correction2 = 1 - beta2**step |
| 169 | + else: |
| 170 | + bias_correction1 = 1 |
| 171 | + bias_correction2 = 1 |
| 172 | + |
| 173 | + # Decay the fast first moment, slow first moment and second moment with an exponential moving average |
| 174 | + if beta_fast != 0.0: |
| 175 | + exp_avg_fast.lerp_(grad, 1 - beta_fast) |
| 176 | + else: |
| 177 | + exp_avg_fast = grad |
| 178 | + exp_avg_slow.lerp_(grad, 1 - beta_slow) |
| 179 | + exp_avg_sq.lerp_(grad.square(), 1 - beta2) |
| 180 | + |
| 181 | + # Correct biases of fast moment and adam second moment, slow moment is not corrected |
| 182 | + fast_moment = exp_avg_fast / bias_correction1 |
| 183 | + adam_second_moment = exp_avg_sq / bias_correction2 |
| 184 | + adam_second_moment = adam_second_moment.sqrt() + eps |
| 185 | + |
| 186 | + return (fast_moment + alpha * exp_avg_slow) / adam_second_moment |
| 187 | + |
| 188 | + |
| 189 | +def _half_life_steps(beta: float, eps: float = 1e-8) -> float: |
| 190 | + """Function that maps beta to the number of steps to reach 0.5. |
| 191 | +
|
| 192 | + Equation: |
| 193 | + f(beta) = log(0.5) / log(beta + eps) - 1 |
| 194 | +
|
| 195 | + Args: |
| 196 | + beta: The beta parameter. |
| 197 | + eps: A small constant to avoid division by zero. |
| 198 | +
|
| 199 | + Returns: |
| 200 | + The number of steps to reach 0.5. |
| 201 | + """ |
| 202 | + return math.log(0.5) / math.log(beta + eps) - 1 |
| 203 | + |
| 204 | + |
| 205 | +def _inverse_half_life_beta(t: float) -> float: |
| 206 | + """Maps number of steps to reach 0.5 to beta. |
| 207 | +
|
| 208 | + Equation: |
| 209 | + f_inv(t) = 0.5^(1 / (t + 1)) |
| 210 | +
|
| 211 | + Args: |
| 212 | + t: The number of steps to reach 0.5. |
| 213 | +
|
| 214 | + Returns: |
| 215 | + The beta parameter. |
| 216 | + """ |
| 217 | + return math.pow(0.5, 1 / (t + 1)) |
| 218 | + |
| 219 | + |
| 220 | +def _linear_half_life_warmup_scheduler( |
| 221 | + step: int, beta_end: float, beta_start: float = 0, num_warmup_steps: int = 1 |
| 222 | +) -> float: |
| 223 | + """Half-life linear warmup scheduler for the beta parameter. |
| 224 | +
|
| 225 | + Equation: |
| 226 | + beta = f_inv((1 - step / num_warmup_steps) * f(beta_start) + (step / num_warmup_steps) * f(beta_end)) |
| 227 | +
|
| 228 | +
|
| 229 | + Args: |
| 230 | + step: The current step of the optimizer. |
| 231 | + beta_end: The final value of the beta parameter. |
| 232 | + beta_start: The initial value of the beta parameter. |
| 233 | + num_warmup_steps: The number of warmup steps. |
| 234 | +
|
| 235 | + Returns: |
| 236 | + The value of the beta parameter at the current step. |
| 237 | + """ |
| 238 | + |
| 239 | + if step < num_warmup_steps: |
| 240 | + a = step / float(num_warmup_steps) |
| 241 | + return _inverse_half_life_beta((1.0 - a) * _half_life_steps(beta_start) + a * _half_life_steps(beta_end)) |
| 242 | + return beta_end |
| 243 | + |
| 244 | + |
| 245 | +def _linear_warmup_scheduler(step: int, alpha_end: float, alpha_start: float = 0, num_warmup_steps: int = 1) -> float: |
| 246 | + """Linear warmup scheduler for the alpha parameter. |
| 247 | +
|
| 248 | + Equation: |
| 249 | + alpha = (1 - step / num_warmup_steps) * alpha_start + (step / num_warmup_steps) * alpha_end |
| 250 | +
|
| 251 | + Args: |
| 252 | + step: The current step of the optimizer. |
| 253 | + alpha_end: The final value of the alpha parameter. |
| 254 | + alpha_start: The initial value of the alpha parameter. |
| 255 | + num_warmup_steps: The number of warmup steps. |
| 256 | +
|
| 257 | + Returns: |
| 258 | + The value of the alpha parameter at the current step. |
| 259 | + """ |
| 260 | + if step < num_warmup_steps: |
| 261 | + a = step / float(num_warmup_steps) |
| 262 | + return (1.0 - a) * alpha_start + a * alpha_end |
| 263 | + return alpha_end |
0 commit comments