diff --git a/emerging_optimizers/soap/soap.py b/emerging_optimizers/soap/soap.py index 5a4de6e..2649c6a 100644 --- a/emerging_optimizers/soap/soap.py +++ b/emerging_optimizers/soap/soap.py @@ -69,8 +69,6 @@ class SOAP(optim.Optimizer): or a callable function that takes the current step as input and returns the frequency. adam_warmup_steps: How many steps to skip preconditioning in the beginning (i.e. use standard AdamW updates) precondition_1d: Whether to precondition 1D gradients (like biases). - trace_normalization: Whether to normalize update by the trace of the kronecker factor matrix - normalize_preconditioned_grads: Whether to normalize preconditioned gradients per layer correct_bias: Whether to use bias correction in Inner Adam and Kronecker factor matrices EMA fp32_matmul_prec: Precision of the matmul operations in optimizer states GEMM operations use_eigh: Whether to use full symmetric eigendecomposition (eigh) to compute the eigenbasis. @@ -83,23 +81,23 @@ class SOAP(optim.Optimizer): More steps can lead to better convergence but increased computation time. max_update_rms: Clip the update RMS to this value (0 means no clipping). use_kl_shampoo: Whether to use KL-Shampoo correction. + correct_shampoo_beta_bias: Whether to correct shampoo beta bias. Decoupled it from correct_bias for + testability because reference implementation of Soap doesn't bias correct shampoo beta. """ def __init__( self, params: ParamsT, - lr: float = 3e-3, - betas: Tuple[float, float] = (0.95, 0.95), + lr: float, + betas: Tuple[float, float] = (0.9, 0.95), shampoo_beta: float = 0.95, eps: float = 1e-8, weight_decay: float = 0.01, use_decoupled_wd: bool = True, use_nesterov: bool = False, precondition_frequency: Union[int, Callable[[int], int]] = 1, - adam_warmup_steps: int = 1, + adam_warmup_steps: int = 0, precondition_1d: bool = False, - trace_normalization: bool = False, - normalize_preconditioned_grads: bool = False, correct_bias: bool = True, fp32_matmul_prec: str = "high", use_eigh: bool = False, @@ -109,12 +107,11 @@ def __init__( power_iter_steps: int = 1, max_update_rms: float = 0.0, use_kl_shampoo: bool = False, + correct_shampoo_beta_bias: bool | None = None, ) -> None: self.precondition_frequency = precondition_frequency self.adam_warmup_steps = adam_warmup_steps self.precondition_1d = precondition_1d - self.trace_normalization = trace_normalization - self.normalize_preconditioned_grads = normalize_preconditioned_grads self.use_nesterov = use_nesterov self.correct_bias = correct_bias self.use_decoupled_wd = use_decoupled_wd @@ -126,6 +123,10 @@ def __init__( self.power_iter_steps = power_iter_steps self.max_update_rms = max_update_rms self.use_kl_shampoo = use_kl_shampoo + if correct_shampoo_beta_bias is not None: + self.correct_shampoo_beta_bias = correct_shampoo_beta_bias + else: + self.correct_shampoo_beta_bias = correct_bias defaults = { "lr": lr, @@ -160,155 +161,132 @@ def step(self, closure: Callable[[], float] | None = None) -> float | None: if "step" not in state: state["step"] = 0 - # State initialization - # (TODO @mkhona): Better way to check state initialization - use state initializer? - if "exp_avg" not in state: + # NOTE: The upstream PyTorch implementations increment the step counter in the middle of the loop + # to be used in bias correction. But this is confusing and error prone if anything else needs to use + # the step counter. + # We decided to follow Python and C convention to increment the step counter at the end of the loop. 
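+                # state["step"] therefore counts completed iterations: it is 0 on the first pass through this loop and is incremented only after the parameter update.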
+ # An explicitly named 1-based iteration/step counter is created for bias correction and other terms + # in the math equation that needs 1-based iteration count. + curr_iter_1_based = state["step"] + 1 + + # TODO(Mkhona): Improve initialization handling. + # - More protective checks can be added to avoid potential issues with checkpointing. + # - Initializing zero buffers can also be avoided. + if state["step"] == 0: + assert all(key not in state for key in ["exp_avg", "exp_avg_sq", "GG"]), ( + "exp_avg and exp_avg_sq and GG should not be initialized at step 0. " + "Some mismatch has been created likely in checkpointing" + ) # Exponential moving average of gradient values state["exp_avg"] = torch.zeros_like(grad) # Exponential moving average of squared gradient values state["exp_avg_sq"] = torch.zeros_like(grad) - - if "Q" not in state: - state["Q"] = [torch.eye(shape, device=grad.device) for shape in grad.shape] + # Initialize kronecker factor matrices + state["GG"] = init_kronecker_factors( + grad, + precondition_1d=self.precondition_1d, + ) # Define kronecker_factor_update_fn based on whether to use KL-Shampoo here # because it needs access to state and group - kronecker_factor_update_fn = partial(update_kronecker_factors, precondition_1d=self.precondition_1d) - if self.use_kl_shampoo: + if not self.use_kl_shampoo: + kronecker_factor_update_fn = partial( + update_kronecker_factors, + precondition_1d=self.precondition_1d, + ) + else: + if "Q" not in state: + assert state["step"] == 0, ( + f"Q should already be initialized at step {state['step']}, Some mismatch has been created " + "likely in checkpointing" + ) + state["Q"] = [torch.eye(shape, device=grad.device) for shape in grad.shape] kronecker_factor_update_fn = partial( update_kronecker_factors_kl_shampoo, eigenbasis_list=state["Q"], eps=group["eps"], ) - # Initialize kronecker factor matrices - if "GG" not in state: - state["GG"] = init_kronecker_factors( - grad, - precondition_1d=self.precondition_1d, - ) + shampoo_beta = group["shampoo_beta"] + if self.correct_shampoo_beta_bias: + shampoo_beta = 1 - (1 - shampoo_beta) / (1 - shampoo_beta**curr_iter_1_based) - # Update preconditioner matrices with gradient statistics, - # do not use shampoo_beta for EMA at first step - with utils.fp32_matmul_precision(self.fp32_matmul_prec): - kronecker_factor_update_fn( - kronecker_factor_list=state["GG"], grad=grad, shampoo_beta=group["shampoo_beta"] - ) + torch.cuda.nvtx.range_push("update_kronecker_factors") + with utils.fp32_matmul_precision(self.fp32_matmul_prec): + kronecker_factor_update_fn(kronecker_factor_list=state["GG"], grad=grad, shampoo_beta=shampoo_beta) + torch.cuda.nvtx.range_pop() - # Increment step counter - state["step"] += 1 + # After the adam_warmup_steps are completed , update eigenbases at precondition_frequency steps + torch.cuda.nvtx.range_push("Update eigen basis") + if _is_eigenbasis_update_step( + state["step"], + self.adam_warmup_steps, + self.precondition_frequency, + ): + # Always use eigh for the first eigenbasis update + use_eigh = self.use_eigh if state["step"] != self.adam_warmup_steps else True + + with utils.fp32_matmul_precision(self.qr_fp32_matmul_prec): + state["Q"], state["exp_avg"], state["exp_avg_sq"] = update_eigenbasis_and_momentum( + kronecker_factor_list=state["GG"], + eigenbasis_list=state.get("Q", None), + exp_avg_sq=state["exp_avg_sq"], + momentum=state["exp_avg"], + use_eigh=use_eigh, + use_adaptive_criteria=self.use_adaptive_criteria, + adaptive_update_tolerance=self.adaptive_update_tolerance, 
+ power_iter_steps=self.power_iter_steps, + ) + torch.cuda.nvtx.range_pop() - # Apply weight decay if group["weight_decay"] > 0.0: if self.use_decoupled_wd: - # Apply decoupled weight decay p.add_(p, alpha=(-group["lr"] * group["weight_decay"])) else: - # add l2 regularization before preconditioning (i.e. like adding a squared loss term) grad += group["weight_decay"] * p - # Projecting gradients to the eigenbases of Shampoo's preconditioner + grad_projected = grad + # Project gradients to the eigenbases of Shampoo's preconditioner torch.cuda.nvtx.range_push("precondition") - with utils.fp32_matmul_precision(self.fp32_matmul_prec): - grad_projected = precondition( - grad=grad, - eigenbasis_list=state["Q"], - dims=[[0], [0]], - ) + if state["step"] >= self.adam_warmup_steps: + with utils.fp32_matmul_precision(self.fp32_matmul_prec): + grad_projected = precondition( + grad=grad, + eigenbasis_list=state["Q"], + dims=[[0], [0]], + ) torch.cuda.nvtx.range_pop() - exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"] - # Calculate the Adam update for the projected gradient tensor - torch.cuda.nvtx.range_push("calculate_adam_update") adam_update = calculate_adam_update( grad_projected, - exp_avg, - exp_avg_sq, + state["exp_avg"], + state["exp_avg_sq"], group["betas"], self.correct_bias, self.use_nesterov, - state["step"], + curr_iter_1_based, # 1-based iteration index is used for bias correction group["eps"], ) - step_size = group["lr"] - torch.cuda.nvtx.range_pop() # Projecting back the preconditioned (by ADAM) exponential moving average of gradients torch.cuda.nvtx.range_push("precondition") - with utils.fp32_matmul_precision(self.fp32_matmul_prec): - norm_precond_grad = precondition( - grad=adam_update, - eigenbasis_list=state["Q"], - dims=[[0], [1]], - ) - torch.cuda.nvtx.range_pop() - - if self.trace_normalization: - if state["GG"][0].numel() > 0: - trace_normalization = 1 / torch.sqrt(torch.trace(state["GG"][0])) - norm_precond_grad = norm_precond_grad / trace_normalization - - if self.normalize_preconditioned_grads: - norm_precond_grad = norm_precond_grad / (1e-30 + torch.mean(norm_precond_grad**2) ** 0.5) - - # Clip the update RMS to a maximum value - _clip_update_rms_in_place(norm_precond_grad, self.max_update_rms) - - torch.cuda.nvtx.range_push("weight update") - p.add_(norm_precond_grad, alpha=-step_size) - torch.cuda.nvtx.range_pop() - - # Update kronecker factor matrices with gradient statistics - shampoo_beta = group["shampoo_beta"] if group["shampoo_beta"] >= 0 else group["betas"][1] - if self.correct_bias: - # step size correction for shampoo kronecker factors EMA - shampoo_beta = 1 - (1 - shampoo_beta) / (1 - shampoo_beta ** (state["step"] + 1)) - - torch.cuda.nvtx.range_push("update_kronecker_factors") - with utils.fp32_matmul_precision(self.fp32_matmul_prec): - kronecker_factor_update_fn( - kronecker_factor_list=state["GG"], - grad=grad, - shampoo_beta=0.0, - ) - torch.cuda.nvtx.range_pop() - - # If current step is the last step to skip preconditioning, initialize eigenbases and - # end first order warmup - if state["step"] == self.adam_warmup_steps: - # Obtain kronecker factor eigenbases from kronecker factor matrices using eigendecomposition - state["Q"] = get_eigenbasis_eigh(state["GG"]) - # rotate momentum to the new eigenbasis + if state["step"] >= self.adam_warmup_steps: with utils.fp32_matmul_precision(self.fp32_matmul_prec): - state["exp_avg"] = precondition( - grad=state["exp_avg"], - eigenbasis_list=state["Q"], - dims=[[0], [0]], - ) - continue - - # After the 
adam_warmup_steps are completed. - # Update eigenbases at precondition_frequency steps - torch.cuda.nvtx.range_push("Update eigen basis") - if _is_eigenbasis_update_step( - state["step"], - self.adam_warmup_steps, - self.precondition_frequency, - ): - with utils.fp32_matmul_precision(self.qr_fp32_matmul_prec): - state["Q"], state["exp_avg"], state["exp_avg_sq"] = update_eigenbasis_and_momentum( - kronecker_factor_list=state["GG"], - eigenbasis_list=state["Q"], - exp_avg_sq=state["exp_avg_sq"], - momentum=state["exp_avg"], - use_eigh=self.use_eigh, - use_adaptive_criteria=self.use_adaptive_criteria, - adaptive_update_tolerance=self.adaptive_update_tolerance, - power_iter_steps=self.power_iter_steps, + precond_update = precondition( + grad=adam_update, + eigenbasis_list=state.get("Q", None), + dims=[[0], [1]], ) + else: + precond_update = adam_update torch.cuda.nvtx.range_pop() + _clip_update_rms_in_place(precond_update, self.max_update_rms) + p.add_(precond_update, alpha=-group["lr"]) + + state["step"] += 1 + return loss @@ -581,7 +559,7 @@ def update_eigenbasis_and_momentum( @torch.compile # type: ignore[misc] def precondition( grad: torch.Tensor, - eigenbasis_list: Optional[List[torch.Tensor]], + eigenbasis_list: Optional[List[torch.Tensor]] = None, dims: Optional[List[List[int]]] = None, ) -> torch.Tensor: """Projects the gradient to and from the eigenbases of the kronecker factor matrices. @@ -607,7 +585,7 @@ def precondition( # Pick contraction dims to project to the eigenbasis dims = [[0], [0]] - if not eigenbasis_list: + if eigenbasis_list is None: # If eigenbases are not provided, return the gradient without any preconditioning return grad @@ -653,7 +631,7 @@ def _is_eigenbasis_update_step( @torch.compile # type: ignore[misc] -def _clip_update_rms_in_place(u: torch.Tensor, max_rms: float = 1.0, eps: float = 1e-12) -> None: +def _clip_update_rms_in_place(u: torch.Tensor, max_rms: float, eps: float = 1e-7) -> None: """Clip the update root mean square (RMS) to a maximum value, in place. Do not clip if max_rms is 0. 
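A standalone sanity check, not part of the patch, of the identity behind `correct_shampoo_beta_bias`: running the Kronecker-factor EMA with the time-varying coefficient `1 - (1 - beta) / (1 - beta**t)` is equivalent to running a plain EMA and applying the usual `1 - beta**t` bias correction afterwards. The 4x4 statistics and the step count below are placeholders.

```python
import torch

torch.manual_seed(0)
beta = 0.95
plain = torch.zeros(4, 4)     # ordinary EMA of the gradient statistics
debiased = torch.zeros(4, 4)  # EMA run with the bias-corrected coefficient

for t in range(1, 11):  # 1-based iteration count, as in curr_iter_1_based
    stat = torch.randn(4, 4)
    plain.lerp_(stat, 1 - beta)
    beta_t = 1 - (1 - beta) / (1 - beta**t)  # same formula as in step() above
    debiased.lerp_(stat, 1 - beta_t)
    # The debiased EMA equals the plain EMA divided by its bias-correction term.
    torch.testing.assert_close(debiased, plain / (1 - beta**t))
```

At t = 1 the corrected coefficient is 0, so the EMA is seeded with the first statistic instead of decaying a zero initialization.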
diff --git a/emerging_optimizers/utils/eig.py b/emerging_optimizers/utils/eig.py index 439ef4a..2876269 100644 --- a/emerging_optimizers/utils/eig.py +++ b/emerging_optimizers/utils/eig.py @@ -72,7 +72,7 @@ def eigh_with_fallback( # Add small identity for numerical stability eye = torch.eye( - x.shape[-1], + x.shape[0], device=x.device, dtype=x.dtype, ) diff --git a/tests/soap_mnist_test.py b/tests/soap_mnist_test.py index 99faea9..8078bf4 100644 --- a/tests/soap_mnist_test.py +++ b/tests/soap_mnist_test.py @@ -47,14 +47,10 @@ def forward(self, x): "eps": 1e-8, "precondition_1d": True, # Enable preconditioning for bias vectors "precondition_frequency": 1, # Update preconditioner every step for testing - "trace_normalization": True, - "shampoo_beta": 0.9, # Slightly more aggressive moving average "fp32_matmul_prec": "high", "qr_fp32_matmul_prec": "high", "use_adaptive_criteria": False, "power_iter_steps": 1, - "use_nesterov": True, - "skip_preconditioning_steps": 0, } @@ -106,15 +102,12 @@ def main() -> None: # Initialize optimizers optimizer_soap = SOAP( model_soap.parameters(), - lr=9.0 * config["lr"], + lr=2.05 * config["lr"], weight_decay=config["weight_decay"], betas=(config["adam_beta1"], config["adam_beta2"]), eps=config["eps"], precondition_frequency=config["precondition_frequency"], - trace_normalization=config["trace_normalization"], - shampoo_beta=config["shampoo_beta"], precondition_1d=config["precondition_1d"], - use_nesterov=config["use_nesterov"], fp32_matmul_prec=config["fp32_matmul_prec"], qr_fp32_matmul_prec=config["qr_fp32_matmul_prec"], use_adaptive_criteria=config["use_adaptive_criteria"], diff --git a/tests/soap_reference.py b/tests/soap_reference.py new file mode 100644 index 0000000..7c9be14 --- /dev/null +++ b/tests/soap_reference.py @@ -0,0 +1,404 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +# MIT License + +# Copyright (c) 2024 Nikhil Vyas + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +from itertools import chain +from typing import Tuple + +import torch +import torch.optim as optim + + +class ReferenceSoap(optim.Optimizer): + """Reference implementation of SOAP algorithm + + https://arxiv.org/abs/2409.11321. + + Note: + Order of operations are slightly changed from original code to match our algorithmic choices. + """ + + def __init__( + self, + params, + lr: float, + betas: Tuple[float, float], + shampoo_beta: float, + eps: float, + weight_decay: float, + precondition_frequency: int, + max_precond_dim: int = 10000, + merge_dims: bool = False, + precondition_1d: bool = False, + data_format: str = "channels_first", + correct_bias: bool = True, + ): + defaults = { + "lr": lr, + "betas": betas, + "shampoo_beta": shampoo_beta, + "eps": eps, + "weight_decay": weight_decay, + "precondition_frequency": precondition_frequency, + "max_precond_dim": max_precond_dim, + "merge_dims": merge_dims, + "precondition_1d": precondition_1d, + "correct_bias": correct_bias, + } + super().__init__(params, defaults) + self._data_format = data_format + + @torch.no_grad() + def step(self, closure=None): + """ + Performs a single optimization step. + + Arguments: + closure (`Callable`, *optional*): A closure that reevaluates the model and returns the loss. + """ + if closure is None: + loss = None + else: + loss = closure() + + for group in self.param_groups: + for p in group["params"]: + if p.grad is None: + continue + grad = p.grad + + state = self.state[p] + + if "step" not in state: + state["step"] = 0 + + # State initialization + if "exp_avg" not in state: + # Exponential moving average of gradient values + state["exp_avg"] = torch.zeros_like(grad) + # Exponential moving average of squared gradient values + state["exp_avg_sq"] = torch.zeros_like(grad) + + self.init_preconditioner( + grad, + state, + precondition_frequency=group["precondition_frequency"], + precondition_1d=group["precondition_1d"], + shampoo_beta=(group["shampoo_beta"] if group["shampoo_beta"] >= 0 else group["betas"][1]), + max_precond_dim=group["max_precond_dim"], + merge_dims=group["merge_dims"], + ) + # NOTE: We don't skip first update + + self.update_preconditioner( + grad, + state, + max_precond_dim=group["max_precond_dim"], + merge_dims=group["merge_dims"], + precondition_1d=group["precondition_1d"], + ) + + # Projecting gradients to the eigenbases of Shampoo's preconditioner + # i.e. 
projecting to the eigenbases of matrices in state['GG'] + grad_projected = self.project( + grad, state, merge_dims=group["merge_dims"], max_precond_dim=group["max_precond_dim"] + ) + + exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"] + beta1, beta2 = group["betas"] + + state["step"] += 1 + + # Decay the first and second moment running average coefficient + # In-place operations to update the averages at the same time + exp_avg.mul_(beta1).add_(grad_projected, alpha=(1.0 - beta1)) + exp_avg_sq.mul_(beta2).add_(grad_projected.square(), alpha=(1.0 - beta2)) + + denom = exp_avg_sq.sqrt().add_(group["eps"]) + + exp_avg_projected = exp_avg + + step_size = group["lr"] + if group["correct_bias"]: + bias_correction1 = 1.0 - beta1 ** (state["step"]) + bias_correction2 = 1.0 - beta2 ** (state["step"]) + step_size = step_size * (bias_correction2**0.5) / bias_correction1 + + # Projecting back the preconditioned (by Adam) exponential moving average of gradients + # to the original space + norm_grad = self.project_back( + exp_avg_projected / denom, + state, + merge_dims=group["merge_dims"], + max_precond_dim=group["max_precond_dim"], + ) + + # NOTE: Weigth decay is moved before parameter update + if group["weight_decay"] > 0.0: + p.add_(p, alpha=(-group["lr"] * group["weight_decay"])) + + p.add_(norm_grad, alpha=-step_size) + + return loss + + def init_preconditioner( + self, + grad, + state, + precondition_frequency=10, + shampoo_beta=0.95, + max_precond_dim=10000, + precondition_1d=False, + merge_dims=False, + ): + """ + Initializes the preconditioner matrices (L and R in the paper). + """ + state["GG"] = [] # Will hold all the preconditioner matrices (L and R in the paper). + if grad.dim() == 1: + if not precondition_1d or grad.shape[0] > max_precond_dim: + state["GG"].append([]) + else: + state["GG"].append(torch.zeros(grad.shape[0], grad.shape[0], device=grad.device)) + else: + if merge_dims: + grad = self.merge_dims(grad, max_precond_dim) + + for sh in grad.shape: + if sh > max_precond_dim: + state["GG"].append([]) + else: + state["GG"].append(torch.zeros(sh, sh, device=grad.device)) + + state["Q"] = None # Will hold all the eigenbases of the preconditioner. + state["precondition_frequency"] = precondition_frequency + state["shampoo_beta"] = shampoo_beta + + def project(self, grad, state, merge_dims=False, max_precond_dim=10000): + """ + Projects the gradient to the eigenbases of the preconditioner. + """ + original_shape = grad.shape + if merge_dims: + if grad.dim() == 4 and self._data_format == "channels_last": + permuted_shape = grad.permute(0, 3, 1, 2).shape + grad = self.merge_dims(grad, max_precond_dim) + + for mat in state["Q"]: + if len(mat) > 0: + grad = torch.tensordot( + grad, + mat, + dims=[[0], [0]], + ) + else: + permute_order = list(range(1, len(grad.shape))) + [0] + grad = grad.permute(permute_order) + + if merge_dims: + if self._data_format == "channels_last" and len(original_shape) == 4: + grad = grad.reshape(permuted_shape).permute(0, 2, 3, 1) + else: + grad = grad.reshape(original_shape) + return grad + + def update_preconditioner(self, grad, state, max_precond_dim=10000, merge_dims=False, precondition_1d=False): + """ + Updates the preconditioner matrices and the eigenbases (L, R, Q_L, Q_R in the paper). 
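+        The running momentum (exp_avg) is projected back to the original space before the factors are updated and re-projected onto the refreshed eigenbasis afterwards, so the first-moment statistics stay consistent with the new basis.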
+ """ + if state["Q"] is not None: + state["exp_avg"] = self.project_back( + state["exp_avg"], state, merge_dims=merge_dims, max_precond_dim=max_precond_dim + ) + if grad.dim() == 1: + if precondition_1d and grad.shape[0] <= max_precond_dim: + state["GG"][0].lerp_(grad.unsqueeze(1) @ grad.unsqueeze(0), 1 - state["shampoo_beta"]) + else: + if merge_dims: + new_grad = self.merge_dims(grad, max_precond_dim) + for idx, sh in enumerate(new_grad.shape): + if sh <= max_precond_dim: + outer_product = torch.tensordot( + new_grad, + new_grad, + dims=[[*chain(range(idx), range(idx + 1, len(new_grad.shape)))]] * 2, + ) + state["GG"][idx].lerp_(outer_product, 1 - state["shampoo_beta"]) + else: + for idx, sh in enumerate(grad.shape): + if sh <= max_precond_dim: + outer_product = torch.tensordot( + grad, + grad, + # Contracts across all dimensions except for k. + dims=[[*chain(range(idx), range(idx + 1, len(grad.shape)))]] * 2, + ) + state["GG"][idx].lerp_(outer_product, 1 - state["shampoo_beta"]) + + if state["Q"] is None: + state["Q"] = self.get_orthogonal_matrix(state["GG"]) + if state["step"] > 0 and (state["step"]) % state["precondition_frequency"] == 0: + state["Q"] = self.get_orthogonal_matrix_QR(state, max_precond_dim, merge_dims) + + if state["step"] > 0: + state["exp_avg"] = self.project( + state["exp_avg"], state, merge_dims=merge_dims, max_precond_dim=max_precond_dim + ) + + # print("wtf1", state["exp_avg"]) + + def project_back(self, grad, state, merge_dims=False, max_precond_dim=10000): + """ + Projects the gradient back to the original space. + """ + original_shape = grad.shape + if merge_dims: + if self._data_format == "channels_last" and grad.dim() == 4: + permuted_shape = grad.permute(0, 3, 1, 2).shape + grad = self.merge_dims(grad, max_precond_dim) + for mat in state["Q"]: + if len(mat) > 0: + grad = torch.tensordot( + grad, + mat, + dims=[[0], [1]], + ) + else: + permute_order = list(range(1, len(grad.shape))) + [0] + grad = grad.permute(permute_order) + + if merge_dims: + if self._data_format == "channels_last" and len(original_shape) == 4: + grad = grad.reshape(permuted_shape).permute(0, 2, 3, 1) + else: + grad = grad.reshape(original_shape) + return grad + + def get_orthogonal_matrix(self, mat): + """ + Computes the eigenbases of the preconditioner using torch.linalg.eigh decomposition. + """ + matrix = [] + for m in mat: + if len(m) == 0: + matrix.append([]) + continue + if m.data.dtype != torch.float: + float_data = False + original_type = m.data.dtype + original_device = m.data.device + matrix.append(m.data.float()) + else: + float_data = True + matrix.append(m.data) + + final = [] + for m in matrix: + if len(m) == 0: + final.append([]) + continue + + try: + _, Q = torch.linalg.eigh(m + 1e-7 * torch.eye(m.shape[0], device=m.device)) + except: + _, Q = torch.linalg.eigh(m.to(torch.float64) + 1e-15 * torch.eye(m.shape[0], device=m.device)) + Q = Q.to(m.dtype) + Q = torch.flip(Q, [1]) + + if not float_data: + Q = Q.to(original_device).type(original_type) + final.append(Q) + return final + + def get_orthogonal_matrix_QR(self, state, max_precond_dim=10000, merge_dims=False): + """ + Computes the eigenbases of the preconditioner using one round of power iteration + followed by torch.linalg.qr decomposition. 
+ """ + precond_list = state["GG"] + orth_list = state["Q"] + + matrix = [] + orth_matrix = [] + for m, o in zip(precond_list, orth_list): + if len(m) == 0: + matrix.append([]) + orth_matrix.append([]) + continue + if m.data.dtype != torch.float: + float_data = False + original_type = m.data.dtype + original_device = m.data.device + matrix.append(m.data.float()) + orth_matrix.append(o.data.float()) + else: + float_data = True + matrix.append(m.data.float()) + orth_matrix.append(o.data.float()) + + orig_shape = state["exp_avg_sq"].shape + if self._data_format == "channels_last" and len(orig_shape) == 4: + permuted_shape = state["exp_avg_sq"].permute(0, 3, 1, 2).shape + if merge_dims: + exp_avg_sq = self.merge_dims(state["exp_avg_sq"], max_precond_dim) + else: + exp_avg_sq = state["exp_avg_sq"] + + final = [] + for ind, (m, o) in enumerate(zip(matrix, orth_matrix)): + if len(m) == 0: + final.append([]) + continue + est_eig = torch.diag(o.T @ m @ o) + sort_idx = torch.argsort(est_eig, descending=True) + exp_avg_sq = exp_avg_sq.index_select(ind, sort_idx) + o = o[:, sort_idx] + power_iter = m @ o + Q, _ = torch.linalg.qr(power_iter) + + if not float_data: + Q = Q.to(original_device).type(original_type) + final.append(Q) + + if merge_dims: + if self._data_format == "channels_last" and len(orig_shape) == 4: + exp_avg_sq = exp_avg_sq.reshape(permuted_shape).permute(0, 2, 3, 1) + else: + exp_avg_sq = exp_avg_sq.reshape(orig_shape) + + state["exp_avg_sq"] = exp_avg_sq + return final diff --git a/tests/test_scalar_optimizers.py b/tests/test_scalar_optimizers.py index 8e8244b..800cbe8 100644 --- a/tests/test_scalar_optimizers.py +++ b/tests/test_scalar_optimizers.py @@ -43,15 +43,14 @@ def setUp(self): # Set up device based on flags self.device = FLAGS.device - def test_calculate_adam_update_simple(self) -> None: - exp_avg_initial = torch.tensor([[1.0]], device=self.device) - exp_avg_sq_initial = torch.tensor([[2.0]], device=self.device) - grad = torch.tensor([[0.5]], device=self.device) - - # Move tensors to the test device - exp_avg_initial = exp_avg_initial.to(self.device) - exp_avg_sq_initial = exp_avg_sq_initial.to(self.device) - grad = grad.to(self.device) + @parameterized.parameters( + {"shape": (3, 3)}, + {"shape": (15, 31)}, + ) + def test_calculate_adam_update_simple(self, shape) -> None: + exp_avg_initial = torch.full(shape, 1.0, device=self.device) + exp_avg_sq_initial = torch.full(shape, 2.0, device=self.device) + grad = torch.full(shape, 0.5, device=self.device) betas = (0.9, 0.99) eps = 1e-8 @@ -73,7 +72,7 @@ def test_calculate_adam_update_simple(self) -> None: eps=eps, ) - initial_param_val_tensor = torch.tensor([[10.0]]).to(self.device) + initial_param_val_tensor = torch.full(shape, 10.0, device=self.device) param = torch.nn.Parameter(initial_param_val_tensor.clone()) param.grad = grad.clone() @@ -106,8 +105,6 @@ def test_calculate_adam_update_simple(self) -> None: expected_param_val_after_step = initial_param_val_tensor - lr * manual_update_value torch.testing.assert_close(param.data, expected_param_val_after_step, atol=1e-6, rtol=1e-6) - self.assertEqual(manual_update_value.shape, (1, 1)) - def test_calculate_laprop_update_with_zero_momentum_equals_rmsprop(self) -> None: # LaProp with momentum (beta1) = 0 should be equivalent to RMSProp. 
exp_avg_initial = torch.tensor([[0.0]], device=self.device) # Momentum is 0, so exp_avg starts at 0 diff --git a/tests/test_soap.py b/tests/test_soap.py index f7864d0..8445f2d 100644 --- a/tests/test_soap.py +++ b/tests/test_soap.py @@ -16,6 +16,7 @@ from functools import partial from typing import Any, List +import soap_reference import torch from absl.testing import absltest, parameterized @@ -65,6 +66,10 @@ def kl_shampoo_update_ref( class SoapFunctionsTest(parameterized.TestCase): + @classmethod + def setUpClass(cls): + torch.manual_seed(13) + def test_init_preconditioner_multidim_tensor_shapes(self) -> None: """Tests init_preconditioner with a multi-dimensional tensor.""" grad = torch.randn(3, 4, 5) @@ -93,17 +98,14 @@ def test_adam_warmup_steps(self, adam_warmup_steps: int) -> None: precondition_frequency=1, ) - dummy_Q = [torch.eye(shape, device=param.device) for shape in param.shape] - for step in range(adam_warmup_steps - 1): + for step in range(adam_warmup_steps): param.grad = torch.randn_like(param) optimizer.step() state = optimizer.state[param] - torch.testing.assert_close( - state["Q"], dummy_Q, atol=0, rtol=0, msg=f"Q should stay identity at step {step}" - ) + self.assertNotIn("Q", state) - for step in range(adam_warmup_steps - 1, adam_warmup_steps + 3): + for step in range(adam_warmup_steps, adam_warmup_steps + 3): param.grad = torch.randn_like(param) optimizer.step() state = optimizer.state[param] @@ -307,6 +309,10 @@ def test_kl_shampoo_update(self, m, n): class SoapTest(parameterized.TestCase): + @classmethod + def setUpClass(cls): + torch.manual_seed(15) + def setUp(self): self.default_config = { "lr": 0.001, @@ -319,7 +325,6 @@ def setUp(self): "adam_warmup_steps": 1, "fp32_matmul_prec": "highest", "use_adaptive_criteria": False, - "trace_normalization": False, "power_iter_steps": 1, } @@ -349,5 +354,143 @@ def test_with_kl_shampoo_10steps_smoke(self): param.grad = None +class SoapVsReferenceTest(parameterized.TestCase): + """Tests that compare SOAP implementation against reference implementation.""" + + @classmethod + def setUpClass(cls): + torch.manual_seed(17) + + @parameterized.product( + shape=[(3, 3), (5, 3), (10, 10), (15, 31)], + num_steps=[2, 5, 7], + precondition_frequency=[1, 2, 5], + correct_bias=[False, True], + ) + def test_update_matches_reference( + self, shape: tuple, num_steps: int, precondition_frequency: int, correct_bias: bool + ): + """Test that SOAP optimizer matches reference implementation for basic config.""" + # Create two identical parameters + param_test = torch.randint(-2, 3, shape, dtype=torch.float32, device="cuda") + param_ref = param_test.clone() + + # NOTE: eps is smaller than usual because reference implementation of Soap applies eps differently than + # torch.optim.AdamW when correct_bias is True. 
+ if correct_bias and shape == (15, 31): + self.skipTest("Skipping large tensor test with correct_bias.") + common_kwargs = dict( + lr=2, + betas=(0.75, 0.75), + shampoo_beta=0.5, + eps=1e-15, + weight_decay=0.125, + precondition_frequency=precondition_frequency, + correct_bias=correct_bias, + ) + + test_optimizer = soap.SOAP( + [param_test], + **common_kwargs, + use_decoupled_wd=True, + adam_warmup_steps=0, + fp32_matmul_prec="highest", + qr_fp32_matmul_prec="highest", + correct_shampoo_beta_bias=False, + ) + ref_optimizer = soap_reference.ReferenceSoap( + [param_ref], + **common_kwargs, + ) + # Run optimization steps with identical gradients + for step in range(num_steps): + grad = torch.randint_like(param_test, -2, 3) + + # Apply same gradient to both + param_test.grad = grad.clone() + param_ref.grad = grad.clone() + + # Step both optimizers + test_optimizer.step() + ref_optimizer.step() + + torch.testing.assert_close( + param_test, + param_ref, + atol=1e-4, + rtol=1e-4, + msg=lambda msg: f"Parameter mismatch at step {step}:\n{msg}", + ) + + param_test.grad = None + param_ref.grad = None + + @parameterized.product( + shape=[(3, 3), (5, 3), (10, 10), (15, 31)], + num_steps=[2, 5, 7], + precondition_frequency=[1, 2, 5], + ) + def test_eigenbasis_matches_reference(self, shape: tuple, num_steps: int, precondition_frequency: int): + param_soap = torch.randint(-2, 3, shape, dtype=torch.float32, device="cuda") + param_ref = param_soap.clone() + + # Disable parameter updates, only test kronecker factors and eigenbases + common_kwargs = dict( + lr=0, + betas=(0, 0), + shampoo_beta=0.75, + eps=1e-8, + weight_decay=0, + precondition_frequency=precondition_frequency, + correct_bias=False, + ) + + test_optimizer = soap.SOAP( + [param_soap], + **common_kwargs, + use_decoupled_wd=False, + adam_warmup_steps=0, + fp32_matmul_prec="highest", + qr_fp32_matmul_prec="highest", + ) + ref_optimizer = soap_reference.ReferenceSoap( + [param_ref], + **common_kwargs, + ) + + for step in range(num_steps): + grad = torch.randint_like(param_soap, -2, 3) + param_soap.grad = grad.clone() + param_ref.grad = grad.clone() + + test_optimizer.step() + ref_optimizer.step() + + param_soap.grad = None + param_ref.grad = None + + test_state = test_optimizer.state[param_soap] + ref_state = ref_optimizer.state[param_ref] + + torch.testing.assert_close( + test_state["GG"], + ref_state["GG"], + atol=1e-5, + rtol=1e-5, + ) + + for eigenbasis_test, eigenbasis_ref in zip(test_state["Q"], ref_state["Q"]): + torch.testing.assert_close( + eigenbasis_test, + eigenbasis_ref, + atol=1e-4, + rtol=1e-4, + msg=lambda msg: f"Eigenbasis mismatch at step {step}:\n{msg}", + ) + + # Compare step counters + self.assertEqual(test_state["step"], ref_state["step"]) + + if __name__ == "__main__": absltest.main()
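For context, a minimal usage sketch of the updated constructor: `lr` no longer has a default and `adam_warmup_steps` now defaults to 0, so preconditioning starts immediately. The import path is inferred from the file layout in this diff, the linear model and hyperparameters are placeholders, and a CUDA-enabled PyTorch build is assumed because `step()` emits NVTX ranges.

```python
import torch

from emerging_optimizers.soap.soap import SOAP  # assumed import path

model = torch.nn.Linear(32, 16)
optimizer = SOAP(
    model.parameters(),
    lr=3e-3,                          # required: the default was removed
    betas=(0.9, 0.95),                # new default betas
    precondition_frequency=1,
    adam_warmup_steps=0,              # new default: no AdamW-only warmup steps
    precondition_1d=True,             # also precondition the 1-D bias, as in the MNIST test
    correct_shampoo_beta_bias=False,  # match the reference SOAP's uncorrected shampoo beta
)

loss = model(torch.randn(8, 32)).square().mean()
loss.backward()
optimizer.step()
optimizer.zero_grad()
```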