emerging_optimizers/soap/soap.py (220 changes: 99 additions & 121 deletions)
@@ -69,8 +69,6 @@ class SOAP(optim.Optimizer):
or a callable function that takes the current step as input and returns the frequency.
adam_warmup_steps: How many steps to skip preconditioning in the beginning (i.e. use standard AdamW updates)
precondition_1d: Whether to precondition 1D gradients (like biases).
trace_normalization: Whether to normalize update by the trace of the kronecker factor matrix
normalize_preconditioned_grads: Whether to normalize preconditioned gradients per layer
correct_bias: Whether to use bias correction in the inner Adam and in the Kronecker factor EMA
fp32_matmul_prec: Precision of the matmul operations in optimizer-state GEMMs
use_eigh: Whether to use full symmetric eigendecomposition (eigh) to compute the eigenbasis.
@@ -83,23 +81,23 @@ class SOAP(optim.Optimizer):
More steps can lead to better convergence but increased computation time.
max_update_rms: Clip the update RMS to this value (0 means no clipping).
use_kl_shampoo: Whether to use KL-Shampoo correction.
correct_shampoo_beta_bias: Whether to bias-correct the shampoo beta. Decoupled from correct_bias for
testability, because the reference SOAP implementation does not bias-correct the shampoo beta.
"""

def __init__(
self,
params: ParamsT,
lr: float = 3e-3,
betas: Tuple[float, float] = (0.95, 0.95),
lr: float,
betas: Tuple[float, float] = (0.9, 0.95),
shampoo_beta: float = 0.95,
eps: float = 1e-8,
weight_decay: float = 0.01,
use_decoupled_wd: bool = True,
use_nesterov: bool = False,
precondition_frequency: Union[int, Callable[[int], int]] = 1,
adam_warmup_steps: int = 1,
adam_warmup_steps: int = 0,
precondition_1d: bool = False,
trace_normalization: bool = False,
normalize_preconditioned_grads: bool = False,
correct_bias: bool = True,
fp32_matmul_prec: str = "high",
use_eigh: bool = False,
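A minimal construction sketch under the new signature (illustrative values; `model` is an assumed torch.nn.Module, and the keyword choices are examples, not recommendations):

optimizer = SOAP(
    model.parameters(),
    lr=3e-3,                         # now required: the 3e-3 default was removed
    betas=(0.9, 0.95),               # new default, shown explicitly
    shampoo_beta=0.95,
    adam_warmup_steps=0,             # precondition from the first step (new default)
    correct_shampoo_beta_bias=None,  # None inherits correct_bias
)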
@@ -109,12 +107,11 @@ def __init__(
power_iter_steps: int = 1,
max_update_rms: float = 0.0,
use_kl_shampoo: bool = False,
correct_shampoo_beta_bias: bool | None = None,
) -> None:
self.precondition_frequency = precondition_frequency
self.adam_warmup_steps = adam_warmup_steps
self.precondition_1d = precondition_1d
self.trace_normalization = trace_normalization
self.normalize_preconditioned_grads = normalize_preconditioned_grads
self.use_nesterov = use_nesterov
self.correct_bias = correct_bias
self.use_decoupled_wd = use_decoupled_wd
@@ -126,6 +123,10 @@ def __init__(
self.power_iter_steps = power_iter_steps
self.max_update_rms = max_update_rms
self.use_kl_shampoo = use_kl_shampoo
if correct_shampoo_beta_bias is not None:
self.correct_shampoo_beta_bias = correct_shampoo_beta_bias
else:
self.correct_shampoo_beta_bias = correct_bias
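# Tri-state resolution, summarized:
#   None  -> inherit correct_bias (the default wiring above)
#   True  -> bias-correct the Kronecker-factor EMA beta
#   False -> match the reference SOAP behavior (no correction)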

defaults = {
"lr": lr,
@@ -160,155 +161,132 @@ def step(self, closure: Callable[[], float] | None = None) -> float | None:
if "step" not in state:
state["step"] = 0

# State initialization
# (TODO @mkhona): Better way to check state initialization - use state initializer?
if "exp_avg" not in state:
# NOTE: The upstream PyTorch implementations increment the step counter in the middle of the loop
# so it can be used in bias correction. But this is confusing and error-prone if anything else
# needs to use the step counter.
# We decided to follow the Python and C convention of incrementing the step counter at the end of
# the loop. An explicitly named 1-based iteration counter is created for bias correction and other
# terms in the math that need a 1-based iteration count.
curr_iter_1_based = state["step"] + 1
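# Illustrative sketch of why bias correction needs the 1-based count:
# with t = curr_iter_1_based, the inner Adam divides by
#   bias_correction1 = 1 - beta1**t
#   bias_correction2 = 1 - beta2**t
# which would be 0 (division by zero) at t = 0; t = 1 exactly undoes the
# zero initialization of the EMAs.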

# TODO(Mkhona): Improve initialization handling.
# - More protective checks can be added to avoid potential issues with checkpointing.
# - Initializing zero buffers can also be avoided.
if state["step"] == 0:
assert all(key not in state for key in ["exp_avg", "exp_avg_sq", "GG"]), (
"exp_avg, exp_avg_sq, and GG should not be initialized at step 0. "
"A mismatch was likely introduced during checkpointing."
)
# Exponential moving average of gradient values
state["exp_avg"] = torch.zeros_like(grad)
# Exponential moving average of squared gradient values
state["exp_avg_sq"] = torch.zeros_like(grad)

if "Q" not in state:
state["Q"] = [torch.eye(shape, device=grad.device) for shape in grad.shape]
# Initialize kronecker factor matrices
state["GG"] = init_kronecker_factors(
grad,
precondition_1d=self.precondition_1d,
)

# Define kronecker_factor_update_fn based on whether to use KL-Shampoo here
# because it needs access to state and group
kronecker_factor_update_fn = partial(update_kronecker_factors, precondition_1d=self.precondition_1d)
if self.use_kl_shampoo:
if not self.use_kl_shampoo:
kronecker_factor_update_fn = partial(
update_kronecker_factors,
precondition_1d=self.precondition_1d,
)
else:
if "Q" not in state:
assert state["step"] == 0, (
f"Q should already be initialized at step {state['step']}, Some mismatch has been created "
"likely in checkpointing"
)
state["Q"] = [torch.eye(shape, device=grad.device) for shape in grad.shape]
kronecker_factor_update_fn = partial(
update_kronecker_factors_kl_shampoo,
eigenbasis_list=state["Q"],
eps=group["eps"],
)
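# Note: the KL-Shampoo path threads the current eigenbases (state["Q"]) and
# eps into the factor update itself, which is why Q must exist before the
# first factor update here rather than only after the Adam warmup.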

# Initialize kronecker factor matrices
if "GG" not in state:
state["GG"] = init_kronecker_factors(
grad,
precondition_1d=self.precondition_1d,
)
shampoo_beta = group["shampoo_beta"]
if self.correct_shampoo_beta_bias:
shampoo_beta = 1 - (1 - shampoo_beta) / (1 - shampoo_beta**curr_iter_1_based)
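# Worked values for the effective beta with shampoo_beta = 0.95:
#   t = 1  -> 1 - 0.05 / (1 - 0.95)   = 0.0   (first step stores pure statistics)
#   t = 2  -> 1 - 0.05 / (1 - 0.9025) ~ 0.487
#   t -> inf                          -> 0.95
# i.e. the EMA is debiased in place, with no separate correction pass.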

# Update preconditioner matrices with gradient statistics,
# do not use shampoo_beta for EMA at first step
with utils.fp32_matmul_precision(self.fp32_matmul_prec):
kronecker_factor_update_fn(
kronecker_factor_list=state["GG"], grad=grad, shampoo_beta=group["shampoo_beta"]
)
torch.cuda.nvtx.range_push("update_kronecker_factors")
with utils.fp32_matmul_precision(self.fp32_matmul_prec):
kronecker_factor_update_fn(kronecker_factor_list=state["GG"], grad=grad, shampoo_beta=shampoo_beta)
torch.cuda.nvtx.range_pop()
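# Assumed semantics of the factor update for a 2D grad (sketch, not the
# helper's exact code):
#   GG[0] <- shampoo_beta * GG[0] + (1 - shampoo_beta) * grad @ grad.T
#   GG[1] <- shampoo_beta * GG[1] + (1 - shampoo_beta) * grad.T @ grad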

# Increment step counter
state["step"] += 1
# After adam_warmup_steps are completed, update the eigenbases every precondition_frequency steps
torch.cuda.nvtx.range_push("Update eigen basis")
if _is_eigenbasis_update_step(
state["step"],
self.adam_warmup_steps,
self.precondition_frequency,
):
# Always use eigh for the first eigenbasis update
use_eigh = self.use_eigh if state["step"] != self.adam_warmup_steps else True

with utils.fp32_matmul_precision(self.qr_fp32_matmul_prec):
state["Q"], state["exp_avg"], state["exp_avg_sq"] = update_eigenbasis_and_momentum(
kronecker_factor_list=state["GG"],
eigenbasis_list=state.get("Q", None),
exp_avg_sq=state["exp_avg_sq"],
momentum=state["exp_avg"],
use_eigh=use_eigh,
use_adaptive_criteria=self.use_adaptive_criteria,
adaptive_update_tolerance=self.adaptive_update_tolerance,
power_iter_steps=self.power_iter_steps,
)
torch.cuda.nvtx.range_pop()
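# Conceptual sketch of the refresh (assumed semantics of the helper): each
# Kronecker factor yields a new Q via eigh, or via a few power-iteration /
# QR steps refining the previous Q; exp_avg and exp_avg_sq are then rotated
# from the old basis to the new one so the inner-Adam statistics stay aligned.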

# Apply weight decay
if group["weight_decay"] > 0.0:
if self.use_decoupled_wd:
# Apply decoupled weight decay
p.add_(p, alpha=(-group["lr"] * group["weight_decay"]))
else:
# Add L2 regularization before preconditioning (i.e. like adding a squared penalty term to the loss)
grad += group["weight_decay"] * p

# Projecting gradients to the eigenbases of Shampoo's preconditioner
grad_projected = grad
# Project gradients to the eigenbases of Shampoo's preconditioner
torch.cuda.nvtx.range_push("precondition")
with utils.fp32_matmul_precision(self.fp32_matmul_prec):
grad_projected = precondition(
grad=grad,
eigenbasis_list=state["Q"],
dims=[[0], [0]],
)
if state["step"] >= self.adam_warmup_steps:
with utils.fp32_matmul_precision(self.fp32_matmul_prec):
grad_projected = precondition(
grad=grad,
eigenbasis_list=state["Q"],
dims=[[0], [0]],
)
torch.cuda.nvtx.range_pop()
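# For a 2D grad this projection amounts to (sketch):
#   grad_projected = Q0.T @ grad @ Q1
# realized as successive tensordot contractions with dims=[[0], [0]].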

exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]

# Calculate the Adam update for the projected gradient tensor
torch.cuda.nvtx.range_push("calculate_adam_update")
adam_update = calculate_adam_update(
grad_projected,
exp_avg,
exp_avg_sq,
state["exp_avg"],
state["exp_avg_sq"],
group["betas"],
self.correct_bias,
self.use_nesterov,
state["step"],
curr_iter_1_based, # 1-based iteration index is used for bias correction
group["eps"],
)
step_size = group["lr"]
torch.cuda.nvtx.range_pop()
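# Assumed semantics of calculate_adam_update (sketch):
#   exp_avg    <- beta1 * exp_avg    + (1 - beta1) * grad_projected
#   exp_avg_sq <- beta2 * exp_avg_sq + (1 - beta2) * grad_projected**2
#   update     =  (exp_avg / bc1) / (sqrt(exp_avg_sq / bc2) + eps)
# with bc1, bc2 the bias corrections built from curr_iter_1_based.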

# Project the Adam-preconditioned exponential moving average of gradients back from the eigenbasis
torch.cuda.nvtx.range_push("precondition")
with utils.fp32_matmul_precision(self.fp32_matmul_prec):
norm_precond_grad = precondition(
grad=adam_update,
eigenbasis_list=state["Q"],
dims=[[0], [1]],
)
torch.cuda.nvtx.range_pop()

if self.trace_normalization:
if state["GG"][0].numel() > 0:
trace_normalization = 1 / torch.sqrt(torch.trace(state["GG"][0]))
norm_precond_grad = norm_precond_grad / trace_normalization

if self.normalize_preconditioned_grads:
norm_precond_grad = norm_precond_grad / (1e-30 + torch.mean(norm_precond_grad**2) ** 0.5)

# Clip the update RMS to a maximum value
_clip_update_rms_in_place(norm_precond_grad, self.max_update_rms)

torch.cuda.nvtx.range_push("weight update")
p.add_(norm_precond_grad, alpha=-step_size)
torch.cuda.nvtx.range_pop()

# Update kronecker factor matrices with gradient statistics
shampoo_beta = group["shampoo_beta"] if group["shampoo_beta"] >= 0 else group["betas"][1]
if self.correct_bias:
# step size correction for shampoo kronecker factors EMA
shampoo_beta = 1 - (1 - shampoo_beta) / (1 - shampoo_beta ** (state["step"] + 1))

torch.cuda.nvtx.range_push("update_kronecker_factors")
with utils.fp32_matmul_precision(self.fp32_matmul_prec):
kronecker_factor_update_fn(
kronecker_factor_list=state["GG"],
grad=grad,
shampoo_beta=0.0,
)
torch.cuda.nvtx.range_pop()

# If current step is the last step to skip preconditioning, initialize eigenbases and
# end first order warmup
if state["step"] == self.adam_warmup_steps:
# Obtain kronecker factor eigenbases from kronecker factor matrices using eigendecomposition
state["Q"] = get_eigenbasis_eigh(state["GG"])
# rotate momentum to the new eigenbasis
if state["step"] >= self.adam_warmup_steps:
with utils.fp32_matmul_precision(self.fp32_matmul_prec):
state["exp_avg"] = precondition(
grad=state["exp_avg"],
eigenbasis_list=state["Q"],
dims=[[0], [0]],
)
continue

# After the adam_warmup_steps are completed.
# Update eigenbases at precondition_frequency steps
torch.cuda.nvtx.range_push("Update eigen basis")
if _is_eigenbasis_update_step(
state["step"],
self.adam_warmup_steps,
self.precondition_frequency,
):
with utils.fp32_matmul_precision(self.qr_fp32_matmul_prec):
state["Q"], state["exp_avg"], state["exp_avg_sq"] = update_eigenbasis_and_momentum(
kronecker_factor_list=state["GG"],
eigenbasis_list=state["Q"],
exp_avg_sq=state["exp_avg_sq"],
momentum=state["exp_avg"],
use_eigh=self.use_eigh,
use_adaptive_criteria=self.use_adaptive_criteria,
adaptive_update_tolerance=self.adaptive_update_tolerance,
power_iter_steps=self.power_iter_steps,
precond_update = precondition(
grad=adam_update,
eigenbasis_list=state.get("Q", None),
dims=[[0], [1]],
)
else:
precond_update = adam_update
torch.cuda.nvtx.range_pop()

_clip_update_rms_in_place(precond_update, self.max_update_rms)
p.add_(precond_update, alpha=-group["lr"])

state["step"] += 1

return loss
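# Recap of the reworked per-parameter flow (summary of the hunk above):
#   1. at step 0, initialize exp_avg / exp_avg_sq / GG (and Q up front for KL-Shampoo)
#   2. EMA-update the Kronecker factors, bias-correcting shampoo_beta if enabled
#   3. refresh the eigenbases on schedule, forcing eigh on the first update
#   4. project the grad, run the inner Adam, project back (projections are
#      skipped while step < adam_warmup_steps)
#   5. RMS-clip, apply the update, and increment the step counter last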


@@ -581,7 +559,7 @@ def update_eigenbasis_and_momentum(
@torch.compile # type: ignore[misc]
def precondition(
grad: torch.Tensor,
eigenbasis_list: Optional[List[torch.Tensor]],
eigenbasis_list: Optional[List[torch.Tensor]] = None,
dims: Optional[List[List[int]]] = None,
) -> torch.Tensor:
"""Projects the gradient to and from the eigenbases of the kronecker factor matrices.
@@ -607,7 +585,7 @@ def precondition(
# Pick contraction dims to project to the eigenbasis
dims = [[0], [0]]

if not eigenbasis_list:
if eigenbasis_list is None:
# If eigenbases are not provided, return the gradient without any preconditioning
return grad
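# The projection core is a loop of tensordot contractions (sketch):
#   for q in eigenbasis_list:
#       grad = torch.tensordot(grad, q, dims=dims)
# dims=[[0], [0]] contracts against each q's rows (apply q.T, project in),
# while dims=[[0], [1]] contracts against its columns (apply q, project out).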

@@ -653,7 +631,7 @@ def _is_eigenbasis_update_step(


@torch.compile # type: ignore[misc]
def _clip_update_rms_in_place(u: torch.Tensor, max_rms: float = 1.0, eps: float = 1e-12) -> None:
def _clip_update_rms_in_place(u: torch.Tensor, max_rms: float, eps: float = 1e-7) -> None:
"""Clip the update root mean square (RMS) to a maximum value, in place.

Do not clip if max_rms is 0.
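A sketch of a body consistent with this docstring (hypothetical, not the repo's exact implementation):

def _clip_update_rms_sketch(u: torch.Tensor, max_rms: float, eps: float = 1e-7) -> None:
    if max_rms == 0.0:
        return  # clipping disabled
    rms = u.square().mean().sqrt()
    # Branchless: the factor is clamped to at most 1, so updates with
    # rms <= max_rms pass through unchanged.
    u.mul_(torch.clamp(max_rms / (rms + eps), max=1.0))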
emerging_optimizers/utils/eig.py (2 changes: 1 addition & 1 deletion)
@@ -72,7 +72,7 @@ def eigh_with_fallback(

# Add small identity for numerical stability
eye = torch.eye(
x.shape[-1],
x.shape[0],
device=x.device,
dtype=x.dtype,
)
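For context, a plausible shape of the fallback around this hunk (assumed; only the identity construction is shown in the diff, and `eps` names a hypothetical jitter scale):

try:
    eigvals, eigvecs = torch.linalg.eigh(x)
except torch.linalg.LinAlgError:
    # Retry on a slightly regularized matrix.
    eigvals, eigvecs = torch.linalg.eigh(x + eps * eye)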
tests/soap_mnist_test.py (9 changes: 1 addition & 8 deletions)
@@ -47,14 +47,10 @@ def forward(self, x):
"eps": 1e-8,
"precondition_1d": True, # Enable preconditioning for bias vectors
"precondition_frequency": 1, # Update preconditioner every step for testing
"trace_normalization": True,
"shampoo_beta": 0.9, # Slightly more aggressive moving average
"fp32_matmul_prec": "high",
"qr_fp32_matmul_prec": "high",
"use_adaptive_criteria": False,
"power_iter_steps": 1,
"use_nesterov": True,
"skip_preconditioning_steps": 0,
}


@@ -106,15 +102,12 @@ def main() -> None:
# Initialize optimizers
optimizer_soap = SOAP(
model_soap.parameters(),
lr=9.0 * config["lr"],
lr=2.05 * config["lr"],
weight_decay=config["weight_decay"],
betas=(config["adam_beta1"], config["adam_beta2"]),
eps=config["eps"],
precondition_frequency=config["precondition_frequency"],
trace_normalization=config["trace_normalization"],
shampoo_beta=config["shampoo_beta"],
precondition_1d=config["precondition_1d"],
use_nesterov=config["use_nesterov"],
fp32_matmul_prec=config["fp32_matmul_prec"],
qr_fp32_matmul_prec=config["qr_fp32_matmul_prec"],
use_adaptive_criteria=config["use_adaptive_criteria"],