Commit 16f8399

Update algorithmic choices for soap and clean up code (#68)
* Update is not skipped on the first step out of Adam warmup
* Current gradient is used in preconditioning
* First Kronecker factor is updated with Shampoo beta
* Explicitly differentiate 0-based and 1-based step counts for different purposes

Signed-off-by: Hao Wu <[email protected]>
1 parent 1bb3f72 commit 16f8399
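The last bullet refers to keeping state["step"] as a 0-based count of completed steps, incremented at the end of the loop, while bias correction uses a separately named 1-based count. A minimal sketch of the distinction (the variable names mirror the diff below; the loop itself is illustrative, not code from the commit):

# state["step"] is 0-based and counts completed steps; bias correction needs the 1-based count.
state = {"step": 0}
beta1 = 0.9

for _ in range(3):
    curr_iter_1_based = state["step"] + 1
    bias_correction1 = 1 - beta1**curr_iter_1_based  # 1 - beta^t with t = 1, 2, 3, ...
    # ... compute and apply the update using bias_correction1 ...
    state["step"] += 1  # incremented at the end of the loop, Python/C style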

File tree

6 files changed, +664 −149 lines changed


emerging_optimizers/soap/soap.py

Lines changed: 99 additions & 121 deletions
@@ -69,8 +69,6 @@ class SOAP(optim.Optimizer):
             or a callable function that takes the current step as input and returns the frequency.
         adam_warmup_steps: How many steps to skip preconditioning in the beginning (i.e. use standard AdamW updates)
         precondition_1d: Whether to precondition 1D gradients (like biases).
-        trace_normalization: Whether to normalize update by the trace of the kronecker factor matrix
-        normalize_preconditioned_grads: Whether to normalize preconditioned gradients per layer
         correct_bias: Whether to use bias correction in Inner Adam and Kronecker factor matrices EMA
         fp32_matmul_prec: Precision of the matmul operations in optimizer states GEMM operations
         use_eigh: Whether to use full symmetric eigendecomposition (eigh) to compute the eigenbasis.
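As the docstring above notes, precondition_frequency may be an int or a callable of the current step. A small illustrative schedule (the function name and breakpoints are made up for the example, not part of the commit):

def precondition_schedule(step: int) -> int:
    # Refresh the eigenbasis every step early on, then back off once statistics stabilize.
    return 1 if step < 100 else 10

# SOAP(..., precondition_frequency=precondition_schedule) would then query this once per step.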
@@ -83,23 +81,23 @@ class SOAP(optim.Optimizer):
             More steps can lead to better convergence but increased computation time.
         max_update_rms: Clip the update RMS to this value (0 means no clipping).
         use_kl_shampoo: Whether to use KL-Shampoo correction.
+        correct_shampoo_beta_bias: Whether to correct shampoo beta bias. Decoupled it from correct_bias for
+            testability because reference implementation of Soap doesn't bias correct shampoo beta.
     """
 
     def __init__(
         self,
         params: ParamsT,
-        lr: float = 3e-3,
-        betas: Tuple[float, float] = (0.95, 0.95),
+        lr: float,
+        betas: Tuple[float, float] = (0.9, 0.95),
         shampoo_beta: float = 0.95,
         eps: float = 1e-8,
         weight_decay: float = 0.01,
         use_decoupled_wd: bool = True,
         use_nesterov: bool = False,
         precondition_frequency: Union[int, Callable[[int], int]] = 1,
-        adam_warmup_steps: int = 1,
+        adam_warmup_steps: int = 0,
         precondition_1d: bool = False,
-        trace_normalization: bool = False,
-        normalize_preconditioned_grads: bool = False,
         correct_bias: bool = True,
         fp32_matmul_prec: str = "high",
         use_eigh: bool = False,
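With this change lr has no default and must be passed explicitly, and the Adam betas default moves from (0.95, 0.95) to (0.9, 0.95). A minimal construction sketch, assuming the class is imported from the module path shown above and using placeholder values:

import torch
from emerging_optimizers.soap.soap import SOAP

model = torch.nn.Linear(16, 4)
# lr is now required; betas default to (0.9, 0.95) and adam_warmup_steps to 0 if omitted.
optimizer = SOAP(model.parameters(), lr=3e-4)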
@@ -109,12 +107,11 @@ def __init__(
         power_iter_steps: int = 1,
         max_update_rms: float = 0.0,
         use_kl_shampoo: bool = False,
+        correct_shampoo_beta_bias: bool | None = None,
     ) -> None:
         self.precondition_frequency = precondition_frequency
         self.adam_warmup_steps = adam_warmup_steps
         self.precondition_1d = precondition_1d
-        self.trace_normalization = trace_normalization
-        self.normalize_preconditioned_grads = normalize_preconditioned_grads
         self.use_nesterov = use_nesterov
         self.correct_bias = correct_bias
         self.use_decoupled_wd = use_decoupled_wd
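The new correct_shampoo_beta_bias flag (declared above, resolved against correct_bias in the next hunk) controls whether step() replaces the raw shampoo_beta with 1 - (1 - beta) / (1 - beta**t), where t is the 1-based iteration. A small self-check on a scalar EMA (an analogy for the Kronecker-factor statistics, not the library code) showing that this adjusted beta makes the running average come out already bias-corrected:

import torch

beta, steps = 0.95, 10
x = torch.randn(steps)

ema = torch.tensor(0.0)        # standard EMA, debiased after the fact
corrected = torch.tensor(0.0)  # EMA run with the adjusted beta_t
for t in range(1, steps + 1):
    ema = beta * ema + (1 - beta) * x[t - 1]
    beta_t = 1 - (1 - beta) / (1 - beta**t)  # the correction applied in step()
    corrected = beta_t * corrected + (1 - beta_t) * x[t - 1]

assert torch.allclose(corrected, ema / (1 - beta**steps))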
@@ -126,6 +123,10 @@ def __init__(
         self.power_iter_steps = power_iter_steps
         self.max_update_rms = max_update_rms
         self.use_kl_shampoo = use_kl_shampoo
+        if correct_shampoo_beta_bias is not None:
+            self.correct_shampoo_beta_bias = correct_shampoo_beta_bias
+        else:
+            self.correct_shampoo_beta_bias = correct_bias
 
         defaults = {
             "lr": lr,
@@ -160,155 +161,132 @@ def step(self, closure: Callable[[], float] | None = None) -> float | None:
                 if "step" not in state:
                     state["step"] = 0
 
-                # State initialization
-                # (TODO @mkhona): Better way to check state initialization - use state initializer?
-                if "exp_avg" not in state:
+                # NOTE: The upstream PyTorch implementations increment the step counter in the middle of the loop
+                # to be used in bias correction. But this is confusing and error prone if anything else needs to use
+                # the step counter.
+                # We decided to follow Python and C convention to increment the step counter at the end of the loop.
+                # An explicitly named 1-based iteration/step counter is created for bias correction and other terms
+                # in the math equation that needs 1-based iteration count.
+                curr_iter_1_based = state["step"] + 1
+
+                # TODO(Mkhona): Improve initialization handling.
+                # - More protective checks can be added to avoid potential issues with checkpointing.
+                # - Initializing zero buffers can also be avoided.
+                if state["step"] == 0:
+                    assert all(key not in state for key in ["exp_avg", "exp_avg_sq", "GG"]), (
+                        "exp_avg and exp_avg_sq and GG should not be initialized at step 0. "
+                        "Some mismatch has been created likely in checkpointing"
+                    )
                     # Exponential moving average of gradient values
                     state["exp_avg"] = torch.zeros_like(grad)
                     # Exponential moving average of squared gradient values
                     state["exp_avg_sq"] = torch.zeros_like(grad)
-
-                if "Q" not in state:
-                    state["Q"] = [torch.eye(shape, device=grad.device) for shape in grad.shape]
+                    # Initialize kronecker factor matrices
+                    state["GG"] = init_kronecker_factors(
+                        grad,
+                        precondition_1d=self.precondition_1d,
+                    )
 
                 # Define kronecker_factor_update_fn based on whether to use KL-Shampoo here
                 # because it needs access to state and group
-                kronecker_factor_update_fn = partial(update_kronecker_factors, precondition_1d=self.precondition_1d)
-                if self.use_kl_shampoo:
+                if not self.use_kl_shampoo:
+                    kronecker_factor_update_fn = partial(
+                        update_kronecker_factors,
+                        precondition_1d=self.precondition_1d,
+                    )
+                else:
+                    if "Q" not in state:
+                        assert state["step"] == 0, (
+                            f"Q should already be initialized at step {state['step']}, Some mismatch has been created "
+                            "likely in checkpointing"
+                        )
+                        state["Q"] = [torch.eye(shape, device=grad.device) for shape in grad.shape]
                     kronecker_factor_update_fn = partial(
                         update_kronecker_factors_kl_shampoo,
                         eigenbasis_list=state["Q"],
                        eps=group["eps"],
                    )
 
-                # Initialize kronecker factor matrices
-                if "GG" not in state:
-                    state["GG"] = init_kronecker_factors(
-                        grad,
-                        precondition_1d=self.precondition_1d,
-                    )
+                shampoo_beta = group["shampoo_beta"]
+                if self.correct_shampoo_beta_bias:
+                    shampoo_beta = 1 - (1 - shampoo_beta) / (1 - shampoo_beta**curr_iter_1_based)
 
-                # Update preconditioner matrices with gradient statistics,
-                # do not use shampoo_beta for EMA at first step
-                with utils.fp32_matmul_precision(self.fp32_matmul_prec):
-                    kronecker_factor_update_fn(
-                        kronecker_factor_list=state["GG"], grad=grad, shampoo_beta=group["shampoo_beta"]
-                    )
+                torch.cuda.nvtx.range_push("update_kronecker_factors")
+                with utils.fp32_matmul_precision(self.fp32_matmul_prec):
+                    kronecker_factor_update_fn(kronecker_factor_list=state["GG"], grad=grad, shampoo_beta=shampoo_beta)
+                torch.cuda.nvtx.range_pop()
 
-                # Increment step counter
-                state["step"] += 1
+                # After the adam_warmup_steps are completed , update eigenbases at precondition_frequency steps
+                torch.cuda.nvtx.range_push("Update eigen basis")
+                if _is_eigenbasis_update_step(
+                    state["step"],
+                    self.adam_warmup_steps,
+                    self.precondition_frequency,
+                ):
+                    # Always use eigh for the first eigenbasis update
+                    use_eigh = self.use_eigh if state["step"] != self.adam_warmup_steps else True
+
+                    with utils.fp32_matmul_precision(self.qr_fp32_matmul_prec):
+                        state["Q"], state["exp_avg"], state["exp_avg_sq"] = update_eigenbasis_and_momentum(
+                            kronecker_factor_list=state["GG"],
+                            eigenbasis_list=state.get("Q", None),
+                            exp_avg_sq=state["exp_avg_sq"],
+                            momentum=state["exp_avg"],
+                            use_eigh=use_eigh,
+                            use_adaptive_criteria=self.use_adaptive_criteria,
+                            adaptive_update_tolerance=self.adaptive_update_tolerance,
+                            power_iter_steps=self.power_iter_steps,
+                        )
+                torch.cuda.nvtx.range_pop()
 
-                # Apply weight decay
                 if group["weight_decay"] > 0.0:
                     if self.use_decoupled_wd:
-                        # Apply decoupled weight decay
                         p.add_(p, alpha=(-group["lr"] * group["weight_decay"]))
                     else:
-                        # add l2 regularization before preconditioning (i.e. like adding a squared loss term)
                        grad += group["weight_decay"] * p
 
-                # Projecting gradients to the eigenbases of Shampoo's preconditioner
+                grad_projected = grad
+                # Project gradients to the eigenbases of Shampoo's preconditioner
                 torch.cuda.nvtx.range_push("precondition")
-                with utils.fp32_matmul_precision(self.fp32_matmul_prec):
-                    grad_projected = precondition(
-                        grad=grad,
-                        eigenbasis_list=state["Q"],
-                        dims=[[0], [0]],
-                    )
+                if state["step"] >= self.adam_warmup_steps:
+                    with utils.fp32_matmul_precision(self.fp32_matmul_prec):
+                        grad_projected = precondition(
+                            grad=grad,
+                            eigenbasis_list=state["Q"],
+                            dims=[[0], [0]],
+                        )
                 torch.cuda.nvtx.range_pop()
 
-                exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]
-
                 # Calculate the Adam update for the projected gradient tensor
-                torch.cuda.nvtx.range_push("calculate_adam_update")
                 adam_update = calculate_adam_update(
                     grad_projected,
-                    exp_avg,
-                    exp_avg_sq,
+                    state["exp_avg"],
+                    state["exp_avg_sq"],
                     group["betas"],
                     self.correct_bias,
                     self.use_nesterov,
-                    state["step"],
+                    curr_iter_1_based,  # 1-based iteration index is used for bias correction
                     group["eps"],
                 )
-                step_size = group["lr"]
-                torch.cuda.nvtx.range_pop()
 
                 # Projecting back the preconditioned (by ADAM) exponential moving average of gradients
                 torch.cuda.nvtx.range_push("precondition")
-                with utils.fp32_matmul_precision(self.fp32_matmul_prec):
-                    norm_precond_grad = precondition(
-                        grad=adam_update,
-                        eigenbasis_list=state["Q"],
-                        dims=[[0], [1]],
-                    )
-                torch.cuda.nvtx.range_pop()
-
-                if self.trace_normalization:
-                    if state["GG"][0].numel() > 0:
-                        trace_normalization = 1 / torch.sqrt(torch.trace(state["GG"][0]))
-                        norm_precond_grad = norm_precond_grad / trace_normalization
-
-                if self.normalize_preconditioned_grads:
-                    norm_precond_grad = norm_precond_grad / (1e-30 + torch.mean(norm_precond_grad**2) ** 0.5)
-
-                # Clip the update RMS to a maximum value
-                _clip_update_rms_in_place(norm_precond_grad, self.max_update_rms)
-
-                torch.cuda.nvtx.range_push("weight update")
-                p.add_(norm_precond_grad, alpha=-step_size)
-                torch.cuda.nvtx.range_pop()
-
-                # Update kronecker factor matrices with gradient statistics
-                shampoo_beta = group["shampoo_beta"] if group["shampoo_beta"] >= 0 else group["betas"][1]
-                if self.correct_bias:
-                    # step size correction for shampoo kronecker factors EMA
-                    shampoo_beta = 1 - (1 - shampoo_beta) / (1 - shampoo_beta ** (state["step"] + 1))
-
-                torch.cuda.nvtx.range_push("update_kronecker_factors")
-                with utils.fp32_matmul_precision(self.fp32_matmul_prec):
-                    kronecker_factor_update_fn(
-                        kronecker_factor_list=state["GG"],
-                        grad=grad,
-                        shampoo_beta=0.0,
-                    )
-                torch.cuda.nvtx.range_pop()
-
-                # If current step is the last step to skip preconditioning, initialize eigenbases and
-                # end first order warmup
-                if state["step"] == self.adam_warmup_steps:
-                    # Obtain kronecker factor eigenbases from kronecker factor matrices using eigendecomposition
-                    state["Q"] = get_eigenbasis_eigh(state["GG"])
-                    # rotate momentum to the new eigenbasis
+                if state["step"] >= self.adam_warmup_steps:
                     with utils.fp32_matmul_precision(self.fp32_matmul_prec):
-                        state["exp_avg"] = precondition(
-                            grad=state["exp_avg"],
-                            eigenbasis_list=state["Q"],
-                            dims=[[0], [0]],
-                        )
-                    continue
-
-                # After the adam_warmup_steps are completed.
-                # Update eigenbases at precondition_frequency steps
-                torch.cuda.nvtx.range_push("Update eigen basis")
-                if _is_eigenbasis_update_step(
-                    state["step"],
-                    self.adam_warmup_steps,
-                    self.precondition_frequency,
-                ):
-                    with utils.fp32_matmul_precision(self.qr_fp32_matmul_prec):
-                        state["Q"], state["exp_avg"], state["exp_avg_sq"] = update_eigenbasis_and_momentum(
-                            kronecker_factor_list=state["GG"],
-                            eigenbasis_list=state["Q"],
-                            exp_avg_sq=state["exp_avg_sq"],
-                            momentum=state["exp_avg"],
-                            use_eigh=self.use_eigh,
-                            use_adaptive_criteria=self.use_adaptive_criteria,
-                            adaptive_update_tolerance=self.adaptive_update_tolerance,
-                            power_iter_steps=self.power_iter_steps,
+                        precond_update = precondition(
+                            grad=adam_update,
+                            eigenbasis_list=state.get("Q", None),
+                            dims=[[0], [1]],
                         )
+                else:
+                    precond_update = adam_update
                 torch.cuda.nvtx.range_pop()
 
+                _clip_update_rms_in_place(precond_update, self.max_update_rms)
+                p.add_(precond_update, alpha=-group["lr"])
+
+                state["step"] += 1
+
         return loss
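Putting the pieces together, the new ordering of step() is: initialize state on step 0, update the Kronecker factors with the bias-corrected beta, refresh the eigenbasis once warmup is over, run the inner Adam update in the rotated basis, project back, apply, and only then increment the 0-based counter. A self-contained toy version for a single 2-D tensor that mirrors this ordering (a sketch of the idea, not the library implementation; momentum rotation, RMS clipping, weight decay, KL-Shampoo, and 1-D handling are omitted, and the helper name is made up):

import torch


def toy_soap_step(p, grad, state, *, lr=1e-2, betas=(0.9, 0.95), shampoo_beta=0.95,
                  eps=1e-8, adam_warmup_steps=0, precondition_frequency=1):
    # 1-based count for bias correction; state["step"] stays 0-based.
    t = state.get("step", 0) + 1
    if state.get("step", 0) == 0:
        state.update(step=0,
                     exp_avg=torch.zeros_like(grad),
                     exp_avg_sq=torch.zeros_like(grad),
                     GG=[torch.zeros(s, s) for s in grad.shape])

    # Kronecker-factor EMA, already bias-corrected via the adjusted beta.
    beta_s = 1 - (1 - shampoo_beta) / (1 - shampoo_beta**t)
    state["GG"][0].mul_(beta_s).add_(grad @ grad.T, alpha=1 - beta_s)
    state["GG"][1].mul_(beta_s).add_(grad.T @ grad, alpha=1 - beta_s)

    # Refresh the eigenbases once warmup is over, every precondition_frequency steps.
    if state["step"] >= adam_warmup_steps and \
            (state["step"] - adam_warmup_steps) % precondition_frequency == 0:
        state["Q"] = [torch.linalg.eigh(g)[1] for g in state["GG"]]

    # Rotate the gradient into the eigenbasis (pass-through during warmup).
    precondition_now = state["step"] >= adam_warmup_steps
    g_rot = state["Q"][0].T @ grad @ state["Q"][1] if precondition_now else grad

    # Inner Adam on the rotated gradient, with 1-based bias correction.
    b1, b2 = betas
    state["exp_avg"].mul_(b1).add_(g_rot, alpha=1 - b1)
    state["exp_avg_sq"].mul_(b2).addcmul_(g_rot, g_rot, value=1 - b2)
    m_hat = state["exp_avg"] / (1 - b1**t)
    v_hat = state["exp_avg_sq"] / (1 - b2**t)
    update = m_hat / (v_hat.sqrt() + eps)

    # Rotate back, apply the update, and bump the 0-based counter last.
    if precondition_now:
        update = state["Q"][0] @ update @ state["Q"][1].T
    p.sub_(update, alpha=lr)
    state["step"] += 1

For example, repeatedly calling toy_soap_step(p, torch.randn(8, 4), {}) on p = torch.randn(8, 4) updates p in place with the same high-level ordering as the new step().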

@@ -581,7 +559,7 @@ def update_eigenbasis_and_momentum(
 @torch.compile  # type: ignore[misc]
 def precondition(
     grad: torch.Tensor,
-    eigenbasis_list: Optional[List[torch.Tensor]],
+    eigenbasis_list: Optional[List[torch.Tensor]] = None,
     dims: Optional[List[List[int]]] = None,
 ) -> torch.Tensor:
     """Projects the gradient to and from the eigenbases of the kronecker factor matrices.
@@ -607,7 +585,7 @@ def precondition(
         # Pick contraction dims to project to the eigenbasis
         dims = [[0], [0]]
 
-    if not eigenbasis_list:
+    if eigenbasis_list is None:
         # If eigenbases are not provided, return the gradient without any preconditioning
         return grad
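Two things change here: eigenbasis_list gains a None default, and the guard switches from `if not eigenbasis_list` (which also treats an empty list as absent) to an explicit `is None` check, so precondition(grad) is a clean pass-through. For orientation, a rough 2-D illustration of what the two dims choices used in step() correspond to, written with plain matmuls rather than the library's contraction code:

import torch

Q_l = torch.linalg.qr(torch.randn(6, 6))[0]  # orthonormal left eigenbasis (illustrative)
Q_r = torch.linalg.qr(torch.randn(4, 4))[0]  # orthonormal right eigenbasis (illustrative)
G = torch.randn(6, 4)

G_rot = Q_l.T @ G @ Q_r       # dims=[[0], [0]]: rotate the gradient into the eigenbasis
G_back = Q_l @ G_rot @ Q_r.T  # dims=[[0], [1]]: rotate an update back out of the eigenbasis
assert torch.allclose(G_back, G, atol=1e-5)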

@@ -653,7 +631,7 @@ def _is_eigenbasis_update_step(
 
 
 @torch.compile  # type: ignore[misc]
-def _clip_update_rms_in_place(u: torch.Tensor, max_rms: float = 1.0, eps: float = 1e-12) -> None:
+def _clip_update_rms_in_place(u: torch.Tensor, max_rms: float, eps: float = 1e-7) -> None:
     """Clip the update root mean square (RMS) to a maximum value, in place.
 
     Do not clip if max_rms is 0.
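The function body is not part of this diff; per its docstring, it rescales u in place so that sqrt(mean(u**2)) does not exceed max_rms and does nothing when max_rms is 0. A sketch of that behavior under those assumptions (not the actual implementation):

import torch

def clip_update_rms_(u: torch.Tensor, max_rms: float, eps: float = 1e-7) -> None:
    # No clipping when max_rms is 0.
    if max_rms == 0:
        return
    rms = u.pow(2).mean().sqrt()
    # Only ever shrink the update: the scale factor is capped at 1.
    u.mul_(torch.clamp(max_rms / (rms + eps), max=1.0))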

emerging_optimizers/utils/eig.py

Lines changed: 1 addition & 1 deletion
@@ -72,7 +72,7 @@ def eigh_with_fallback(
 
     # Add small identity for numerical stability
     eye = torch.eye(
-        x.shape[-1],
+        x.shape[0],
         device=x.device,
         dtype=x.dtype,
     )
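For the square symmetric matrices this helper is applied to, x.shape[0] and x.shape[-1] name the same dimension in the 2-D case, so the change mainly makes the intent explicit. A sketch of the identity-jitter idea (the real eigh_with_fallback presumably adds retry/fallback handling on top, which is not shown in this diff):

import torch

def eigh_with_jitter(x: torch.Tensor, jitter: float = 1e-7):
    # Add a small scaled identity so eigh stays well conditioned.
    eye = torch.eye(x.shape[0], device=x.device, dtype=x.dtype)
    return torch.linalg.eigh(x + jitter * eye)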

tests/soap_mnist_test.py

Lines changed: 1 addition & 8 deletions
@@ -47,14 +47,10 @@ def forward(self, x):
     "eps": 1e-8,
     "precondition_1d": True,  # Enable preconditioning for bias vectors
     "precondition_frequency": 1,  # Update preconditioner every step for testing
-    "trace_normalization": True,
-    "shampoo_beta": 0.9,  # Slightly more aggressive moving average
     "fp32_matmul_prec": "high",
     "qr_fp32_matmul_prec": "high",
     "use_adaptive_criteria": False,
     "power_iter_steps": 1,
-    "use_nesterov": True,
-    "skip_preconditioning_steps": 0,
 }

@@ -106,15 +102,12 @@ def main() -> None:
     # Initialize optimizers
     optimizer_soap = SOAP(
         model_soap.parameters(),
-        lr=9.0 * config["lr"],
+        lr=2.05 * config["lr"],
         weight_decay=config["weight_decay"],
         betas=(config["adam_beta1"], config["adam_beta2"]),
         eps=config["eps"],
         precondition_frequency=config["precondition_frequency"],
-        trace_normalization=config["trace_normalization"],
-        shampoo_beta=config["shampoo_beta"],
         precondition_1d=config["precondition_1d"],
-        use_nesterov=config["use_nesterov"],
         fp32_matmul_prec=config["fp32_matmul_prec"],
         qr_fp32_matmul_prec=config["qr_fp32_matmul_prec"],
         use_adaptive_criteria=config["use_adaptive_criteria"],
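With the removed keyword arguments gone from the config, the optimizer is driven like any other torch.optim optimizer. A minimal loop sketch with placeholder model, data, and learning rate (none of these values come from the test file, and the import path is assumed from the file layout above):

import torch
from emerging_optimizers.soap.soap import SOAP

model = torch.nn.Linear(28 * 28, 10)
optimizer = SOAP(model.parameters(), lr=1e-3, precondition_frequency=1, precondition_1d=True)

for _ in range(10):
    x, y = torch.randn(32, 28 * 28), torch.randint(0, 10, (32,))
    loss = torch.nn.functional.cross_entropy(model(x), y)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()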
