Commit 2fdbc22

Support for SOAP optimizer (#16)
* added library for SOAP

Signed-off-by: mikail <[email protected]>
1 parent 812dce6 commit 2fdbc22

18 files changed: +3099 -0 lines changed

Lines changed: 29 additions & 0 deletions
@@ -0,0 +1,29 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .adam import calculate_adam_update
from .ademamix import calculate_ademamix_update, calculate_sim_ademamix_update
from .laprop import calculate_laprop_update
from .lion import calculate_lion_update
from .signum import calculate_signum_update


__all__ = [
    "calculate_adam_update",
    "calculate_sim_ademamix_update",
    "calculate_ademamix_update",
    "calculate_signum_update",
    "calculate_laprop_update",
    "calculate_lion_update",
]
Lines changed: 89 additions & 0 deletions
@@ -0,0 +1,89 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Tuple

import torch


__all__ = [
    "calculate_adam_update",
]


@torch.compile  # type: ignore[misc]
@torch.no_grad()  # type: ignore[misc]
def calculate_adam_update(
    grad: torch.Tensor,
    exp_avg: torch.Tensor,
    exp_avg_sq: torch.Tensor,
    betas: Tuple[float, float],
    correct_bias: bool,
    use_nesterov: bool,
    step: int,
    eps: float,
) -> torch.Tensor:
    """Performs the Adam update.

    This function performs the computation of one step of Adam.

    The update rule is as follows:

    .. math::
        m_t = \\beta_1 m_{t-1} + (1 - \\beta_1) g_t \\\\
        v_t = \\beta_2 v_{t-1} + (1 - \\beta_2) g_t^2 \\\\
        \\hat{m}_t = \\frac{m_t}{1 - \\beta_1^t} \\\\
        \\hat{v}_t = \\frac{v_t}{1 - \\beta_2^t} \\\\
        \\text{update} = \\frac{\\hat{m}_t}{\\sqrt{\\hat{v}_t} + \\epsilon}

    Args:
        grad: The gradient tensor.
        exp_avg: The accumulated first moment of the gradient.
        exp_avg_sq: The accumulated second moment of the gradient.
        betas: The EMA beta coefficients (beta1, beta2) for the Adam update.
        correct_bias: Whether to correct the bias of the Adam update.
        use_nesterov: Whether to use Nesterov momentum.
        step: The current step of the optimizer, used to compute the bias correction terms.
        eps: The epsilon for the Adam second moment update.

    Returns:
        The Adam update.
    """

    beta1, beta2 = betas

    # Decay the first and second moment running average coefficient
    exp_avg.lerp_(grad, 1 - beta1)
    exp_avg_sq.lerp_(grad.square(), 1 - beta2)

    # step size correction for optimizer states EMA
    bias_correction1 = 1.0
    bias_correction2 = 1.0
    if correct_bias:
        # step size correction for ADAM moments EMA
        bias_correction1 = 1.0 - beta1**step
        bias_correction2 = 1.0 - beta2**step

    if use_nesterov:
        # Apply Nesterov momentum correction, optionally with bias correction
        bias_correction_nesterov = (1 - beta1 ** (step + 1)) if correct_bias else 1.0
        momentum = beta1 * exp_avg / bias_correction_nesterov + (1 - beta1) * grad / bias_correction1
    else:
        # Use standard momentum, optionally with bias correction
        momentum = exp_avg / bias_correction1

    # construct the denominator of the inner ADAM optimizer
    adam_second_moment = exp_avg_sq / bias_correction2
    adam_second_moment = adam_second_moment.sqrt() + eps
    return momentum / adam_second_moment
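
For reference, a minimal usage sketch for the kernel above, not part of the commit: it assumes calculate_adam_update as defined in this file is in scope, and the toy objective, learning rate, and training loop are illustrative placeholders.

```python
import torch

# Toy parameter and Adam state buffers (illustrative only).
param = torch.nn.Parameter(torch.randn(4, 4))
exp_avg = torch.zeros_like(param)
exp_avg_sq = torch.zeros_like(param)
lr = 1e-3

for step in range(1, 11):  # steps are 1-indexed so the bias corrections are well defined
    loss = (param**2).sum()
    loss.backward()

    # The kernel mutates exp_avg / exp_avg_sq in place and returns the update direction.
    update = calculate_adam_update(
        grad=param.grad,
        exp_avg=exp_avg,
        exp_avg_sq=exp_avg_sq,
        betas=(0.9, 0.999),
        correct_bias=True,
        use_nesterov=False,
        step=step,
        eps=1e-8,
    )

    # The learning rate is applied by the caller; the kernel returns the raw direction.
    with torch.no_grad():
        param.add_(update, alpha=-lr)
    param.grad = None
```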
Lines changed: 263 additions & 0 deletions
@@ -0,0 +1,263 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from typing import Optional, Tuple

import torch


__all__ = [
    "calculate_sim_ademamix_update",
    "calculate_ademamix_update",
]


@torch.compile  # type: ignore[misc]
@torch.no_grad()  # type: ignore[misc]
def calculate_sim_ademamix_update(
    grad: torch.Tensor,
    exp_avg: torch.Tensor,
    exp_avg_sq: torch.Tensor,
    num_beta_fast_warmup_steps: Optional[int],
    min_beta_fast: float,
    betas: Tuple[float, float],
    step: int,
    eps: float,
    correct_bias: bool,
    alpha: float = 2,
) -> torch.Tensor:
    """Performs the simplified AdEMAMix update.

    This function performs the computation of one step of simplified AdEMAMix.
    Based on https://github.com/DepenM/Simplified-AdEMAMix/blob/main/simplified_AdEMAMix.py
    and https://arxiv.org/abs/2409.03137.

    The update rule is as follows:

    .. math::
        m_t = \\beta_{\\text{fast}} m_{t-1} + g_t \\\\
        v_t = \\beta_2 v_{t-1} + (1 - \\beta_2) g_t^2 \\\\
        \\hat{m}_t = \\frac{m_t}{(1 - \\beta_{\\text{fast}}^t) / (1 - \\beta_{\\text{fast}})} \\\\
        \\hat{v}_t = \\frac{v_t}{1 - \\beta_2^t} \\\\
        \\text{update} = \\frac{\\alpha g_t + \\hat{m}_t}{\\sqrt{\\hat{v}_t} + \\epsilon}

    Args:
        grad: The gradient tensor.
        exp_avg: The accumulated first moment of the gradient.
        exp_avg_sq: The accumulated second moment of the gradient.
        num_beta_fast_warmup_steps: Number of warmup steps used to increase beta_fast.
        min_beta_fast: The minimum beta_fast value used at initialization.
        betas: The EMA beta coefficients (beta_fast_final, beta2).
        step: The current step of the optimizer, used to compute the bias correction terms.
        eps: The epsilon for the Adam second moment update.
        correct_bias: Whether to correct the bias of the AdEMAMix update.
        alpha: Coefficient for mixing the current gradient and the EMA.

    Returns:
        The simplified-AdEMAMix update.
    """
    beta_fast_final, beta2 = betas

    # Compute beta_fast based on scheduler
    if num_beta_fast_warmup_steps is not None:
        beta_fast = _linear_half_life_warmup_scheduler(
            step, beta_end=beta_fast_final, beta_start=min_beta_fast, num_warmup_steps=num_beta_fast_warmup_steps
        )
    else:
        beta_fast = beta_fast_final

    # Decay the first moment "theory style": https://arxiv.org/abs/2502.02431
    exp_avg.mul_(beta_fast).add_(grad, alpha=1.0)

    # Decay the second moment exponential moving average
    exp_avg_sq.lerp_(grad.square(), 1 - beta2)

    if correct_bias:
        # theory style bias correction
        bias_correction1 = (1 - beta_fast**step) / (1 - beta_fast)
        bias_correction2 = 1 - beta2**step
    else:
        bias_correction1 = 1
        bias_correction2 = 1

    # step size correction for optimizer states EMA
    momentum = exp_avg / bias_correction1
    adam_second_moment = exp_avg_sq / bias_correction2
    adam_second_moment = adam_second_moment.sqrt() + eps

    return (alpha * grad + momentum) / adam_second_moment
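
A usage sketch for calculate_sim_ademamix_update, assuming the function above is in scope; the tensors and hyperparameter values are illustrative. Note that, unlike the Adam kernel, the first-moment buffer here accumulates an un-normalized sum (the gradient is added with weight 1), which is why the bias correction divides by (1 - beta_fast**step) / (1 - beta_fast).

```python
import torch

grad = torch.randn(8)
exp_avg = torch.zeros(8)     # un-normalized first-moment accumulator
exp_avg_sq = torch.zeros(8)  # second-moment EMA

for step in range(1, 4):  # step must be >= 1 when correct_bias=True
    update = calculate_sim_ademamix_update(
        grad=grad,
        exp_avg=exp_avg,
        exp_avg_sq=exp_avg_sq,
        num_beta_fast_warmup_steps=None,  # no beta_fast warmup: use betas[0] directly
        min_beta_fast=0.9,
        betas=(0.99, 0.999),
        step=step,
        eps=1e-8,
        correct_bias=True,
        alpha=2.0,
    )
    # `update` mixes the raw gradient (scaled by alpha) with the bias-corrected momentum.
```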
@torch.compile  # type: ignore[misc]
@torch.no_grad()  # type: ignore[misc]
def calculate_ademamix_update(
    grad: torch.Tensor,
    exp_avg_fast: torch.Tensor,
    exp_avg_slow: torch.Tensor,
    exp_avg_sq: torch.Tensor,
    num_beta_slow_warmup_steps: Optional[int],
    num_alpha_warmup_steps: Optional[int],
    betas: Tuple[float, float, float],
    step: int,
    eps: float,
    correct_bias: bool,
    alpha: float = 2,
) -> torch.Tensor:
    """Performs the AdEMAMix update.

    This function performs the computation of one step of AdEMAMix.
    Based on https://github.com/apple/ml-ademamix/blob/main/pytorch/ademamix.py
    and https://arxiv.org/abs/2409.03137.

    The update rule is as follows:

    .. math::
        m_t^{\\text{fast}} = \\beta_{\\text{fast}} m_{t-1}^{\\text{fast}} + (1 - \\beta_{\\text{fast}}) g_t \\\\
        m_t^{\\text{slow}} = \\beta_{\\text{slow}} m_{t-1}^{\\text{slow}} + (1 - \\beta_{\\text{slow}}) g_t \\\\
        v_t = \\beta_2 v_{t-1} + (1 - \\beta_2) g_t^2 \\\\
        \\hat{m}_t^{\\text{fast}} = \\frac{m_t^{\\text{fast}}}{1 - \\beta_{\\text{fast}}^t} \\\\
        \\hat{v}_t = \\frac{v_t}{1 - \\beta_2^t} \\\\
        \\text{update} = \\frac{\\hat{m}_t^{\\text{fast}} + \\alpha m_t^{\\text{slow}}}{\\sqrt{\\hat{v}_t} + \\epsilon}

    Args:
        grad: The gradient tensor.
        exp_avg_fast: The accumulated first moment of the gradient with fast time constant.
        exp_avg_slow: The accumulated first moment of the gradient with slow time constant.
        exp_avg_sq: The accumulated second moment of the gradient.
        num_beta_slow_warmup_steps: Number of warmup steps used to increase beta_slow.
        num_alpha_warmup_steps: Number of warmup steps used to increase alpha.
        betas: The EMA beta coefficients (beta_fast, beta2, beta_slow_final).
        step: The current step of the optimizer, used to compute the bias correction terms.
        eps: The epsilon for the Adam second moment update.
        correct_bias: Whether to correct the bias of the AdEMAMix update.
        alpha: Coefficient for mixing the current gradient and EMA, the final value to use in case of scheduling.

    Returns:
        The AdEMAMix update.
    """
    beta_fast, beta2, beta_slow_final = betas

    # Compute alpha based on scheduler with linear warmup
    if num_alpha_warmup_steps is not None:
        alpha = _linear_warmup_scheduler(step, alpha_end=alpha, alpha_start=0, num_warmup_steps=num_alpha_warmup_steps)

    # Compute beta_slow based on scheduler with half-life linear warmup
    # beta_start is usually set to beta_fast
    if num_beta_slow_warmup_steps is not None:
        beta_slow = _linear_half_life_warmup_scheduler(
            step, beta_end=beta_slow_final, beta_start=beta_fast, num_warmup_steps=num_beta_slow_warmup_steps
        )
    else:
        beta_slow = beta_slow_final

    if correct_bias:
        bias_correction1 = 1 - beta_fast**step
        bias_correction2 = 1 - beta2**step
    else:
        bias_correction1 = 1
        bias_correction2 = 1

    # Decay the fast first moment, slow first moment and second moment with an exponential moving average
    if beta_fast != 0.0:
        exp_avg_fast.lerp_(grad, 1 - beta_fast)
    else:
        exp_avg_fast = grad
    exp_avg_slow.lerp_(grad, 1 - beta_slow)
    exp_avg_sq.lerp_(grad.square(), 1 - beta2)

    # Correct biases of fast moment and adam second moment, slow moment is not corrected
    fast_moment = exp_avg_fast / bias_correction1
    adam_second_moment = exp_avg_sq / bias_correction2
    adam_second_moment = adam_second_moment.sqrt() + eps

    return (fast_moment + alpha * exp_avg_slow) / adam_second_moment
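
A usage sketch for calculate_ademamix_update, assuming the function above and the scheduler helpers defined later in this file are in scope. The values below are illustrative; the three betas entries are unpacked as (beta_fast, beta2, beta_slow_final), and the warmup arguments enable the alpha and beta_slow schedules.

```python
import torch

grad = torch.randn(8)
exp_avg_fast = torch.zeros(8)
exp_avg_slow = torch.zeros(8)
exp_avg_sq = torch.zeros(8)

for step in range(1, 4):  # step must be >= 1 when correct_bias=True
    update = calculate_ademamix_update(
        grad=grad,
        exp_avg_fast=exp_avg_fast,
        exp_avg_slow=exp_avg_slow,
        exp_avg_sq=exp_avg_sq,
        num_beta_slow_warmup_steps=100,  # beta_slow warms up from beta_fast to betas[2]
        num_alpha_warmup_steps=100,      # alpha warms up linearly from 0 to 2.0
        betas=(0.9, 0.999, 0.9999),      # (beta_fast, beta2, beta_slow_final)
        step=step,
        eps=1e-8,
        correct_bias=True,
        alpha=2.0,
    )
```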
def _half_life_steps(beta: float, eps: float = 1e-8) -> float:
    """Function that maps beta to the number of steps to reach 0.5.

    Equation:
        f(beta) = log(0.5) / log(beta + eps) - 1

    Args:
        beta: The beta parameter.
        eps: A small constant to avoid division by zero.

    Returns:
        The number of steps to reach 0.5.
    """
    return math.log(0.5) / math.log(beta + eps) - 1


def _inverse_half_life_beta(t: float) -> float:
    """Maps number of steps to reach 0.5 to beta.

    Equation:
        f_inv(t) = 0.5^(1 / (t + 1))

    Args:
        t: The number of steps to reach 0.5.

    Returns:
        The beta parameter.
    """
    return math.pow(0.5, 1 / (t + 1))
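
A quick numerical check of the two helpers above, assuming they are in scope: _half_life_steps converts a beta into (roughly) the number of steps after which a contribution to the moving average has decayed to one half, and _inverse_half_life_beta inverts that mapping, so the warmup scheduler below can interpolate linearly in half-life space rather than directly in beta space.

```python
# beta = 0.99 has a half-life of roughly 68 steps; the inverse maps it back.
t = _half_life_steps(0.99)
beta = _inverse_half_life_beta(t)
print(f"half-life of beta=0.99: {t:.1f} steps, round-trip beta: {beta:.4f}")
```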
def _linear_half_life_warmup_scheduler(
    step: int, beta_end: float, beta_start: float = 0, num_warmup_steps: int = 1
) -> float:
    """Half-life linear warmup scheduler for the beta parameter.

    Equation:
        beta = f_inv((1 - step / num_warmup_steps) * f(beta_start) + (step / num_warmup_steps) * f(beta_end))

    Args:
        step: The current step of the optimizer.
        beta_end: The final value of the beta parameter.
        beta_start: The initial value of the beta parameter.
        num_warmup_steps: The number of warmup steps.

    Returns:
        The value of the beta parameter at the current step.
    """
    if step < num_warmup_steps:
        a = step / float(num_warmup_steps)
        return _inverse_half_life_beta((1.0 - a) * _half_life_steps(beta_start) + a * _half_life_steps(beta_end))
    return beta_end


def _linear_warmup_scheduler(step: int, alpha_end: float, alpha_start: float = 0, num_warmup_steps: int = 1) -> float:
    """Linear warmup scheduler for the alpha parameter.

    Equation:
        alpha = (1 - step / num_warmup_steps) * alpha_start + (step / num_warmup_steps) * alpha_end

    Args:
        step: The current step of the optimizer.
        alpha_end: The final value of the alpha parameter.
        alpha_start: The initial value of the alpha parameter.
        num_warmup_steps: The number of warmup steps.

    Returns:
        The value of the alpha parameter at the current step.
    """
    if step < num_warmup_steps:
        a = step / float(num_warmup_steps)
        return (1.0 - a) * alpha_start + a * alpha_end
    return alpha_end
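
To make the two warmup schedules concrete, the sketch below (assuming the helpers above are in scope, with illustrative endpoints and warmup length) prints a few points of each: alpha ramps linearly from 0 to its final value, while beta is interpolated linearly in half-life space, so most of its increase happens early in the warmup.

```python
# Illustrative schedule values over a 100-step warmup (not part of the commit).
num_warmup_steps = 100
for step in (0, 25, 50, 75, 100):
    beta = _linear_half_life_warmup_scheduler(
        step, beta_end=0.9999, beta_start=0.9, num_warmup_steps=num_warmup_steps
    )
    alpha = _linear_warmup_scheduler(step, alpha_end=2.0, alpha_start=0, num_warmup_steps=num_warmup_steps)
    print(f"step {step:3d}: beta_slow={beta:.5f}  alpha={alpha:.2f}")
```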
