diff --git a/emerging_optimizers/mixin.py b/emerging_optimizers/mixin.py
new file mode 100644
index 0000000..508166e
--- /dev/null
+++ b/emerging_optimizers/mixin.py
@@ -0,0 +1,52 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Literal
+
+import torch
+
+
+WeightDecayT = Literal["decoupled", "independent", "l2"]
+
+
+class WeightDecayMixin:
+    """Mixin for weight decay.
+
+    Supports different types of weight decay:
+    - "decoupled": weight decay is applied directly to params without changing gradients
+    - "independent": similar to decoupled weight decay, but without tying weight decay to the learning rate
+    - "l2": classic L2 regularization
+    """
+
+    def _apply_weight_decay_inplace(
+        self,
+        p: torch.Tensor,
+        grad: torch.Tensor,
+        lr: float,
+        weight_decay: float,
+    ) -> None:
+        """Depending on the weight decay option, either p or grad is updated in place."""
+        if weight_decay == 0.0:
+            return
+
+        weight_decay_method = getattr(self, "weight_decay_method", "l2")
+        if weight_decay_method == "decoupled":
+            p.add_(p, alpha=(-weight_decay * lr))
+        elif weight_decay_method == "independent":
+            p.add_(p, alpha=-weight_decay)
+        elif weight_decay_method == "l2":
+            grad.add_(p, alpha=weight_decay)
+        else:
+            raise ValueError(f"Invalid weight decay method: {weight_decay_method}")
diff --git a/emerging_optimizers/orthogonalized_optimizers/muon.py b/emerging_optimizers/orthogonalized_optimizers/muon.py
index 7a585e8..8200fb1 100644
--- a/emerging_optimizers/orthogonalized_optimizers/muon.py
+++ b/emerging_optimizers/orthogonalized_optimizers/muon.py
@@ -18,7 +18,8 @@
 from torch.optim.optimizer import ParamsT
 
 from emerging_optimizers import triton_kernels
-from emerging_optimizers.orthogonalized_optimizers.muon_utils import newton_schulz
+from emerging_optimizers.mixin import WeightDecayT
+from emerging_optimizers.orthogonalized_optimizers import muon_utils
 from emerging_optimizers.orthogonalized_optimizers.orthogonalized_optimizer import OrthogonalizedOptimizer, _args_doc
 
 
@@ -64,10 +65,10 @@ def __init__(
         params: ParamsT,
         lr: float = 3e-4,
         momentum_beta: float = 0.95,
-        use_nesterov: bool = False,
         weight_decay: float = 0.01,
-        use_decoupled_wd: bool = True,
-        use_independent_wd: bool = False,
+        *,
+        use_nesterov: bool = False,
+        weight_decay_method: WeightDecayT = "decoupled",
         fp32_matmul_prec: str = "medium",
         coefficient_type: str = "quintic",
         num_ns_steps: int = 5,
@@ -97,7 +98,12 @@ def scaled_orthogonalize_fn(grad: torch.Tensor) -> torch.Tensor:
                 f"Orthogonalizing grad with {num_ns_steps} steps, {coefficient_type} coefficient, "
                 f"{scale_mode} scale mode, extra_scale_factor={extra_scale_factor}"
             )
-            orth_grad = newton_schulz(grad, steps=num_ns_steps, coefficient_type=coefficient_type, use_syrk=use_syrk)
+            orth_grad = muon_utils.newton_schulz(
+                grad,
+                steps=num_ns_steps,
+                coefficient_type=coefficient_type,
+                use_syrk=use_syrk,
+            )
             scale_factor = 
get_muon_scale_factor(grad.size(-2), grad.size(-1), mode=scale_mode) return orth_grad * scale_factor * extra_scale_factor @@ -105,12 +111,11 @@ def scaled_orthogonalize_fn(grad: torch.Tensor) -> torch.Tensor: params, lr, momentum_beta, - use_nesterov, - weight_decay, - use_decoupled_wd, - use_independent_wd, - fp32_matmul_prec, - scaled_orthogonalize_fn, + use_nesterov=use_nesterov, + weight_decay=weight_decay, + weight_decay_method=weight_decay_method, + fp32_matmul_prec=fp32_matmul_prec, + scaled_orthogonalize_fn=scaled_orthogonalize_fn, ) diff --git a/emerging_optimizers/orthogonalized_optimizers/orthogonalized_optimizer.py b/emerging_optimizers/orthogonalized_optimizers/orthogonalized_optimizer.py index 4d65e62..78dcae1 100644 --- a/emerging_optimizers/orthogonalized_optimizers/orthogonalized_optimizer.py +++ b/emerging_optimizers/orthogonalized_optimizers/orthogonalized_optimizer.py @@ -26,23 +26,23 @@ from absl import logging from torch.optim.optimizer import ParamsT +from emerging_optimizers import mixin as opt_mixin from emerging_optimizers import utils _args_doc = """params: Iterable of parameters to optimize or dicts defining parameter groups lr: The learning rate used by the internal SGD. momentum_beta: The momentum used by the internal SGD. - use_nesterov: Whether to use Nesterov-style momentum in the internal SGD. weight_decay: The weight decay used by the optimizer, default to be decoupled weight decay. See Decoupled Weight Decay Regularization: https://arxiv.org/abs/1711.05101 - use_decoupled_wd: Whether to use decoupled weight decay, default to be True. - use_independent_wd: Whether to use independent weight decay (https://arxiv.org/abs/2510.19093), - default to be False. + use_nesterov: Whether to use Nesterov-style momentum in the internal SGD. + weight_decay_method: Method to apply weight decay, see :class:`~emerging_optimizers.mixin.WeightDecayMixin` + for more details. fp32_matmul_prec: Precision of the matmul operations in optimizer states GEMM operations. """ -class OrthogonalizedOptimizer(optim.Optimizer): +class OrthogonalizedOptimizer(opt_mixin.WeightDecayMixin, optim.Optimizer): """Base class for orthogonalized optimizers. This class is a wrapper around a base optimizer that performs orthogonalization on the updates. @@ -99,10 +99,10 @@ def __init__( params: ParamsT, lr: float, momentum_beta: float, - use_nesterov: bool, weight_decay: float, - use_decoupled_wd: bool, - use_independent_wd: bool, + *, + use_nesterov: bool, + weight_decay_method: opt_mixin.WeightDecayT, fp32_matmul_prec: str, scaled_orthogonalize_fn: Callable | None = None, **kwargs: Any, @@ -113,8 +113,7 @@ def __init__( self.fp32_matmul_prec = fp32_matmul_prec self.use_nesterov = use_nesterov - self.use_decoupled_wd = use_decoupled_wd - self.use_independent_wd = use_independent_wd + self.weight_decay_method = weight_decay_method default_args_dict = dict( lr=lr, @@ -155,19 +154,12 @@ def step(self, closure: Callable[[], float] | None = None) -> float | None: # Subsequent update to exp_avg are all inplace, so it is not assigned back to state. 
exp_avg = state["momentum_buffer"] - # Apply weight decay - if group["weight_decay"] > 0.0: - if self.use_decoupled_wd: - # Apply weight decay directly to params without changing gradients - if self.use_independent_wd: - # do not tie weight decay and learning rate - weight_decay_scale = group["weight_decay"] - else: - weight_decay_scale = group["weight_decay"] * group["lr"] - p.add_(p, alpha=(-weight_decay_scale)) - else: - # add l2 regularization before preconditioning (i.e. adding a squared loss term) - grad += group["weight_decay"] * p + self._apply_weight_decay_inplace( + p, + grad, + group["lr"], + group["weight_decay"], + ) # update momentum buffer with EMA of gradient exp_avg.lerp_(grad, 1 - group["momentum_beta"]) diff --git a/emerging_optimizers/orthogonalized_optimizers/scion.py b/emerging_optimizers/orthogonalized_optimizers/scion.py index f2ea66c..3dce5fe 100644 --- a/emerging_optimizers/orthogonalized_optimizers/scion.py +++ b/emerging_optimizers/orthogonalized_optimizers/scion.py @@ -60,6 +60,7 @@ def __init__( params: ParamsT, lr: float = 3e-4, momentum_beta: float = 0.95, + *, fp32_matmul_prec: str = "medium", coefficient_type: str = "quintic", num_ns_steps: int = 5, @@ -69,14 +70,11 @@ def __init__( raise ValueError(f"num_ns_steps must be at least 1, got {num_ns_steps}") # Add checks for weight decay arguments to enable Franke-Wolfe step. - logging.info("Scion does not use weight decay. Setting weight_decay to 1.") + logging.info( + "Scion does not use weight decay. Setting weight_decay to 1 and weight_decay_method to decoupled." + ) weight_decay = 1 - - logging.info("Scion does not use weight decay. Setting use_decoupled_wd to True to allow Franke-Wolfe.") - use_decoupled_wd = True - - logging.info("Scion does not use weight decay. Setting use_independent_wd to False to allow Franke-Wolfe.") - use_independent_wd = False + weight_decay_method = "decoupled" logging.info("Scion does not use Nesterov momentum. Setting use_nesterov to False.") use_nesterov = False @@ -93,10 +91,9 @@ def scaled_orthogonalize_fn(grad: torch.Tensor) -> torch.Tensor: params, lr, momentum_beta, - use_nesterov, weight_decay, - use_decoupled_wd, - use_independent_wd, - fp32_matmul_prec, - scaled_orthogonalize_fn, + use_nesterov=use_nesterov, + weight_decay_method=weight_decay_method, # type: ignore[arg-type] + fp32_matmul_prec=fp32_matmul_prec, + scaled_orthogonalize_fn=scaled_orthogonalize_fn, ) diff --git a/emerging_optimizers/psgd/psgd.py b/emerging_optimizers/psgd/psgd.py index 28fed2e..a8e28bb 100644 --- a/emerging_optimizers/psgd/psgd.py +++ b/emerging_optimizers/psgd/psgd.py @@ -18,9 +18,9 @@ import torch from torch.optim.optimizer import ParamsT +from emerging_optimizers import mixin as opt_mixin +from emerging_optimizers.psgd import psgd_kron_contractions, psgd_utils from emerging_optimizers.psgd.procrustes_step import procrustes_step -from emerging_optimizers.psgd.psgd_kron_contractions import apply_preconditioner, partial_contraction -from emerging_optimizers.psgd.psgd_utils import norm_lower_bound_spd, uniformize_q_in_place from emerging_optimizers.soap.soap import _clip_update_rms_in_place @@ -29,7 +29,7 @@ ] -class PSGDPro(torch.optim.Optimizer): +class PSGDPro(opt_mixin.WeightDecayMixin, torch.optim.Optimizer): """Implements a variant of the PSGD optimization algorithm (PSGD-Kron-Whiten with Procrustes step for preconditioner update). 
Preconditioned Stochastic Gradient Descent (PSGD) (https://arxiv.org/abs/1512.04202) is a preconditioned optimization algorithm @@ -42,8 +42,8 @@ class PSGDPro(torch.optim.Optimizer): params: Iterable of parameters to optimize or dicts defining parameter groups lr: The learning rate to use weight_decay: Weight decay coefficient - use_decoupled_wd: Whether to use decoupled weight decay, see Decoupled Weight Decay Regularization: - https://arxiv.org/abs/1711.05101. + weight_decay_method: Method to apply weight decay, see :class:`~emerging_optimizers.mixin.WeightDecayMixin` + for more details. momentum: Momentum coefficient for exponential moving average of gradient. beta_lip: EMA beta for the Lipschitz constants. precond_lr: Inner learning rate for the preconditioner. @@ -59,8 +59,9 @@ def __init__( params: ParamsT, lr: float = 3e-3, weight_decay: float = 0.01, - use_decoupled_wd: bool = True, momentum: float = 0.9, + *, + weight_decay_method: opt_mixin.WeightDecayT = "decoupled", beta_lip: float = 0.9, precond_lr: float = 0.1, precond_init_scale: float = 1.0, @@ -69,7 +70,7 @@ def __init__( warmup_steps: int = 10000, max_update_rms: float = 0.0, ) -> None: - self.use_decoupled_wd = use_decoupled_wd + self.weight_decay_method = weight_decay_method self.max_update_rms = max_update_rms self.precond_init_scale = precond_init_scale self.damping_noise_scale = damping_noise_scale @@ -117,14 +118,12 @@ def step(self, closure: Callable[[], float] | None = None) -> float | None: precond_init_scale=self.precond_init_scale, ) - # weight decay - if group["weight_decay"] > 0.0: - if self.use_decoupled_wd: - # Apply decoupled weight decay - p.add_(p, alpha=(-group["lr"] * group["weight_decay"])) - else: - # add l2 regularization before preconditioning (i.e. adding a squared loss term) - grad += group["weight_decay"] * p + self._apply_weight_decay_inplace( + p, + grad, + group["lr"], + group["weight_decay"], + ) # update momentum buffer with EMA of gradient exp_avg = state["exp_avg"] @@ -140,10 +139,10 @@ def step(self, closure: Callable[[], float] | None = None) -> float | None: state["Q"], state["L"] = _update_precond_procrustes( state["Q"], state["L"], exp_avg, self.damping_noise_scale, precond_lr, beta_lip ) - uniformize_q_in_place(state["Q"]) + psgd_utils.uniformize_q_in_place(state["Q"]) # Get weight update by preconditioning the momentum - update = apply_preconditioner(state["Q"], exp_avg) + update = psgd_kron_contractions.apply_preconditioner(state["Q"], exp_avg) _clip_update_rms_in_place(update, self.max_update_rms) # Apply weight update @@ -200,13 +199,13 @@ def _update_precond_procrustes( lip_const_list: List of Lipschitz constants for the Kronecker factors. """ dampened_momentum = exp_avg + (damping_noise_scale + 1e-7 * exp_avg.abs()) * torch.randn_like(exp_avg) - pg = apply_preconditioner(q_list, dampened_momentum) + pg = psgd_kron_contractions.apply_preconditioner(q_list, dampened_momentum) total_numel = pg.numel() updated_q_list: List[torch.Tensor] = [] updated_lip_const_list: List[torch.Tensor] = [] for dim, q in enumerate(q_list): # compute gradient covariance - precond_grad_cov = partial_contraction(pg, pg, dim) + precond_grad_cov = psgd_kron_contractions.partial_contraction(pg, pg, dim) if q.dim() < 2: # diagonal or scalar-structured preconditioner q, updated_lip_const = _update_1d_preconditioner( @@ -246,7 +245,7 @@ def _update_matrix_preconditioner( lip_const: Updated Lipschitz constant for this dimension. 
""" normalization = total_numel / q.shape[0] - ell = norm_lower_bound_spd(precond_grad_cov) + normalization + ell = psgd_utils.norm_lower_bound_spd(precond_grad_cov) + normalization lip_const = torch.max(beta_lip * lip_const + (1 - beta_lip) * ell, ell) q = q - precond_lr / lip_const * (precond_grad_cov @ q - normalization * q) q = procrustes_step(q) diff --git a/emerging_optimizers/soap/soap.py b/emerging_optimizers/soap/soap.py index 2649c6a..36d5190 100644 --- a/emerging_optimizers/soap/soap.py +++ b/emerging_optimizers/soap/soap.py @@ -24,16 +24,13 @@ from typing_extensions import override import torch -import torch.optim as optim from absl import logging +from torch import optim from torch.optim.optimizer import ParamsT -from emerging_optimizers import utils -from emerging_optimizers.scalar_optimizers import calculate_adam_update -from emerging_optimizers.soap.soap_utils import ( - get_eigenbasis_eigh, - get_eigenbasis_qr, -) +from emerging_optimizers import mixin as opt_mixin +from emerging_optimizers import scalar_optimizers, utils +from emerging_optimizers.soap import soap_utils __all__ = [ @@ -45,7 +42,7 @@ ] -class SOAP(optim.Optimizer): +class SOAP(opt_mixin.WeightDecayMixin, optim.Optimizer): """Implements a variant of SOAP (ShampoO with Adam in the Preconditioner eigenbasis) algorithm. SOAP (https://arxiv.org/abs/2409.11321) is a preconditioned optimizer that combines the benefits of Shampoo's @@ -61,8 +58,8 @@ class SOAP(optim.Optimizer): instead of betas[1] if >= 0 eps: Inner Adam's epsilon for numerical stability weight_decay: Weight decay coefficient - use_decoupled_wd: Whether to use decoupled weight decay, see Decoupled Weight Decay Regularization: - https://arxiv.org/abs/1711.05101. + weight_decay_method: Method to apply weight decay, see :class:`~emerging_optimizers.mixin.WeightDecayMixin` + for more details. use_nesterov: uses Nesterov momentum in Adam (https://cs229.stanford.edu/proj2015/054_report.pdf, https://openreview.net/forum?id=OM0jvwB8jIp57ZJjtNEZ) precondition_frequency: How often to update the preconditioner. 
Can be an integer for fixed frequency @@ -93,7 +90,8 @@ def __init__( shampoo_beta: float = 0.95, eps: float = 1e-8, weight_decay: float = 0.01, - use_decoupled_wd: bool = True, + *, + weight_decay_method: opt_mixin.WeightDecayT = "decoupled", use_nesterov: bool = False, precondition_frequency: Union[int, Callable[[int], int]] = 1, adam_warmup_steps: int = 0, @@ -114,7 +112,7 @@ def __init__( self.precondition_1d = precondition_1d self.use_nesterov = use_nesterov self.correct_bias = correct_bias - self.use_decoupled_wd = use_decoupled_wd + self.weight_decay_method = weight_decay_method self.fp32_matmul_prec = fp32_matmul_prec self.use_eigh = use_eigh self.qr_fp32_matmul_prec = qr_fp32_matmul_prec @@ -239,11 +237,12 @@ def step(self, closure: Callable[[], float] | None = None) -> float | None: ) torch.cuda.nvtx.range_pop() - if group["weight_decay"] > 0.0: - if self.use_decoupled_wd: - p.add_(p, alpha=(-group["lr"] * group["weight_decay"])) - else: - grad += group["weight_decay"] * p + self._apply_weight_decay_inplace( + p, + grad, + group["lr"], + group["weight_decay"], + ) grad_projected = grad # Project gradients to the eigenbases of Shampoo's preconditioner @@ -258,7 +257,7 @@ def step(self, closure: Callable[[], float] | None = None) -> float | None: torch.cuda.nvtx.range_pop() # Calculate the Adam update for the projected gradient tensor - adam_update = calculate_adam_update( + adam_update = scalar_optimizers.calculate_adam_update( grad_projected, state["exp_avg"], state["exp_avg_sq"], @@ -434,7 +433,8 @@ def update_kronecker_factors_kl_shampoo( eps: Small offset for numerical stability. eigenval_exp: Exponent of the eigenvalues. """ - assert grad.dim() == 2, "KL-Shampoo mathematical correction is only supported for 2D tensors" + if grad.dim() != 2: + raise TypeError("KL-Shampoo mathematical correction is only supported for 2D tensors") # Scale the gradient matrix by the approximate eigenvalues and the eigenbasis # G@Q_R@λ_R^(−1)@Q_R.T@G.T/dim(GG.T) and G.T@Q_L@λ_L^(−1)@Q_L.T@G/dim(G.TG) @@ -523,7 +523,7 @@ def update_eigenbasis_and_momentum( # Step 2: Update eigenbases torch.cuda.nvtx.range_push("eigenbasis update step 2: update Q") if use_eigh: - updated_eigenbasis_list = get_eigenbasis_eigh( + updated_eigenbasis_list = soap_utils.get_eigenbasis_eigh( kronecker_factor_list, convert_to_float, eigenbasis_list, @@ -532,7 +532,7 @@ def update_eigenbasis_and_momentum( ) else: # Use QR decomposition and power iteration (orthogonal iteration) - updated_eigenbasis_list, exp_avg_sq = get_eigenbasis_qr( + updated_eigenbasis_list, exp_avg_sq = soap_utils.get_eigenbasis_qr( kronecker_factor_list, eigenbasis_list, exp_avg_sq, diff --git a/emerging_optimizers/triton_kernels/syrk.py b/emerging_optimizers/triton_kernels/syrk.py index 49673e5..d8a1c0e 100644 --- a/emerging_optimizers/triton_kernels/syrk.py +++ b/emerging_optimizers/triton_kernels/syrk.py @@ -203,15 +203,18 @@ def tsyrk_ex( Returns: Output tensor of shape (N, N) """ - assert a.dtype == torch.bfloat16, "Input tensor must be bfloat16" - assert a.dim() == 2, "Input tensor must be 2D" - assert a.is_contiguous() or a.T.is_contiguous(), "invalid input tensor layout. a or a.T must be contiguous." + if a.dtype != torch.bfloat16: + raise TypeError("Input tensor must be bfloat16") + if a.dim() != 2: + raise TypeError("Input tensor must be 2D") + if not (a.is_contiguous() or a.T.is_contiguous()): + raise TypeError("invalid input tensor layout. 
a or a.T must be contiguous.") N, K = a.shape - assert (c is None and beta == 0.0) or (c is not None and c.shape == (N, N)), ( - "if c is provided, c must be of shape (N, N)" - ) - assert c is None or c.is_contiguous() or c.T.is_contiguous(), "if c is provided, c or c.T must be contiguous" + if not ((c is None and beta == 0.0) or (c is not None and c.shape == (N, N))): + raise RuntimeError("if c is provided, c must be of shape (N, N)") + if not (c is None or c.is_contiguous() or c.T.is_contiguous()): + raise RuntimeError("if c is provided, c or c.T must be contiguous") d = torch.empty((N, N), device=a.device, dtype=a.dtype) diff --git a/pyproject.toml b/pyproject.toml index 4f7c8cc..3b8bbfb 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -177,9 +177,13 @@ source = ["emerging_optimizers/", "/workspace/emerging_optimizers"] [tool.coverage.report] exclude_lines = [ - "raise ValueError" + "except ImportError", ] exclude_also = [ - "@triton" + "@triton", + ".*sm_version", + "if closure", + "loss = closure", + "raise .*Error", ] diff --git a/tests/ci/L0_Tests_CPU.sh b/tests/ci/L0_Tests_CPU.sh index cf882b3..ca6a1a9 100644 --- a/tests/ci/L0_Tests_CPU.sh +++ b/tests/ci/L0_Tests_CPU.sh @@ -14,9 +14,9 @@ export TORCH_COMPILE_DISABLE=1 error=0 -torchrun --nproc_per_node=8 --no-python coverage run -p tests/test_distributed_muon_utils_cpu.py || error=1 -torchrun --nproc_per_node=4 --no-python coverage run -p tests/test_distributed_muon_utils_cpu.py || error=1 -coverage run -p --source=emerging_optimizers tests/test_scalar_optimizers.py --device=cpu || error=1 -coverage run -p --source=emerging_optimizers tests/test_procrustes_step.py --device=cpu || error=1 +torchrun --nproc_per_node=8 --no-python coverage run -p tests/test_distributed_muon_utils_cpu.py -v -2 || error=1 +torchrun --nproc_per_node=4 --no-python coverage run -p tests/test_distributed_muon_utils_cpu.py -v -2 || error=1 +coverage run -p --source=emerging_optimizers tests/test_scalar_optimizers.py --device=cpu -v -2 || error=1 +coverage run -p --source=emerging_optimizers tests/test_procrustes_step.py --device=cpu -v -2 || error=1 exit "${error}" diff --git a/tests/ci/L0_Tests_GPU.sh b/tests/ci/L0_Tests_GPU.sh index c07e1d4..bd28ef8 100644 --- a/tests/ci/L0_Tests_GPU.sh +++ b/tests/ci/L0_Tests_GPU.sh @@ -16,18 +16,18 @@ export CUDA_VISIBLE_DEVICES=0 export TORCH_ALLOW_TF32_CUBLAS_OVERRIDE=0 error=0 -coverage run -p --source=emerging_optimizers tests/test_muon_utils.py || error=1 -coverage run -p --source=emerging_optimizers tests/test_orthogonalized_optimizer.py || error=1 -coverage run -p --source=emerging_optimizers tests/test_soap_utils.py || error=1 -coverage run -p --source=emerging_optimizers tests/test_soap.py || error=1 -coverage run -p --source=emerging_optimizers tests/soap_mnist_test.py || error=1 -coverage run -p --source=emerging_optimizers tests/test_scalar_optimizers.py --device=cuda || error=1 -coverage run -p --source=emerging_optimizers tests/test_spectral_clipping_utils.py || error=1 -coverage run -p --source=emerging_optimizers tests/test_triton_kernels.py TsyrkIntegerInputTest || error=1 -coverage run -p --source=emerging_optimizers tests/test_normalized_optimizer.py --device=cuda || error=1 -coverage run -p --source=emerging_optimizers tests/normalized_optimizer_convergence_test.py --device=cuda || error=1 -coverage run -p --source=emerging_optimizers tests/test_psgd_contractions.py --device=cuda || error=1 -coverage run -p --source=emerging_optimizers tests/test_psgd_utils.py --device=cuda || error=1 -coverage run 
-p --source=emerging_optimizers tests/test_psgd_convergence.py --device=cuda || error=1 +coverage run -p --source=emerging_optimizers tests/test_muon_utils.py -v -2 || error=1 +coverage run -p --source=emerging_optimizers tests/test_orthogonalized_optimizer.py -v -2 || error=1 +coverage run -p --source=emerging_optimizers tests/test_soap_utils.py -v -2 || error=1 +coverage run -p --source=emerging_optimizers tests/test_soap.py -v -2 || error=1 +coverage run -p --source=emerging_optimizers tests/soap_mnist_test.py -v -2 || error=1 +coverage run -p --source=emerging_optimizers tests/test_scalar_optimizers.py --device=cuda -v -2 || error=1 +coverage run -p --source=emerging_optimizers tests/test_spectral_clipping_utils.py -v -2 || error=1 +coverage run -p --source=emerging_optimizers tests/test_triton_kernels.py TsyrkIntegerInputTest -v -2 || error=1 +coverage run -p --source=emerging_optimizers tests/test_normalized_optimizer.py --device=cuda -v -2 || error=1 +coverage run -p --source=emerging_optimizers tests/normalized_optimizer_convergence_test.py --device=cuda -v -2 || error=1 +coverage run -p --source=emerging_optimizers tests/test_psgd_contractions.py --device=cuda -v -2 || error=1 +coverage run -p --source=emerging_optimizers tests/test_psgd_utils.py --device=cuda -v -2 || error=1 +coverage run -p --source=emerging_optimizers tests/test_psgd_convergence.py --device=cuda -v -2 || error=1 exit "${error}" diff --git a/tests/test_orthogonalized_optimizer.py b/tests/test_orthogonalized_optimizer.py index e3a7892..39312c1 100644 --- a/tests/test_orthogonalized_optimizer.py +++ b/tests/test_orthogonalized_optimizer.py @@ -17,11 +17,32 @@ import torch.nn as nn from absl.testing import absltest, parameterized -from emerging_optimizers.orthogonalized_optimizers import muon +from emerging_optimizers.orthogonalized_optimizers import muon, scion from emerging_optimizers.orthogonalized_optimizers.orthogonalized_optimizer import OrthogonalizedOptimizer class OrthogonalizedOptimizerTest(parameterized.TestCase): + @parameterized.product( + weight_decay_method=["decoupled", "independent", "l2"], + shape=[(5, 7), (33, 65), (127, 257)], + use_nesterov=[True, False], + fp32_matmul_prec=["highest", "medium", "low"], + ) + def test_smoke(self, weight_decay_method, shape, use_nesterov, fp32_matmul_prec) -> None: + test_param = nn.Parameter(torch.randint(-5, 5, shape, dtype=torch.float32, device="cuda")) + test_param.grad = torch.randint_like(test_param, -5, 5) + + orthogonalized_opt = OrthogonalizedOptimizer( + [test_param], + lr=2, + momentum_beta=0, + weight_decay=0.5, + use_nesterov=use_nesterov, + weight_decay_method=weight_decay_method, + fp32_matmul_prec=fp32_matmul_prec, + ) + orthogonalized_opt.step() + @parameterized.parameters( {"shape": (5, 7)}, {"shape": (33, 65)}, @@ -42,8 +63,7 @@ def test_orthogonalized_optimizer_core_matches_sgd(self, shape) -> None: momentum_beta=0, use_nesterov=False, weight_decay=0.5, - use_decoupled_wd=True, - use_independent_wd=False, + weight_decay_method="decoupled", fp32_matmul_prec="highest", ) @@ -84,8 +104,7 @@ def test_orthogonalized_optimizer_core_matches_sgd_with_momentum(self, shape) -> momentum_beta=0.5, use_nesterov=False, weight_decay=0.0, - use_decoupled_wd=False, - use_independent_wd=False, + weight_decay_method="l2", fp32_matmul_prec="highest", ) @@ -135,8 +154,7 @@ def dummy_interleaved_split_orth_fn(x: torch.Tensor) -> torch.Tensor: momentum_beta=0, use_nesterov=False, weight_decay=0.0, - use_decoupled_wd=False, - use_independent_wd=False, + 
weight_decay_method="l2", fp32_matmul_prec="highest", scaled_orthogonalize_fn=dummy_interleaved_split_orth_fn, ) @@ -154,12 +172,12 @@ def dummy_interleaved_split_orth_fn(x: torch.Tensor) -> torch.Tensor: class MuonTest(parameterized.TestCase): - @parameterized.parameters( - {"shape": (5, 7)}, - {"shape": (33, 65)}, - {"shape": (127, 257)}, + @parameterized.product( + shape=[(5, 7), (33, 65), (127, 257)], + weight_decay_method=["decoupled", "independent", "l2"], + use_nesterov=[True, False], ) - def test_smoke(self, shape) -> None: + def test_smoke(self, shape, weight_decay_method, use_nesterov) -> None: """Smoke test Muon optimizer. Most functionality of muon is tested in muon_utils. This test only entures everything run through the optimizer class. @@ -167,7 +185,7 @@ def test_smoke(self, shape) -> None: test_param = nn.Parameter(torch.randint(-5, 5, shape, dtype=torch.float32, device="cuda")) test_param.grad = torch.randint_like(test_param, -5, 5) - muon_opt = muon.Muon([test_param]) + muon_opt = muon.Muon([test_param], weight_decay_method=weight_decay_method, use_nesterov=use_nesterov) muon_opt.step() def test_use_syrk_match_without_syrk(self) -> None: @@ -195,28 +213,41 @@ def test_use_independent_wd(self) -> None: # Test with independent weight decay: with lr=0, weight decay should still be applied # With lr=0, no gradient update occurs, so param should be exactly (1-wd)*param - indep_param = nn.Parameter(torch.randint(-5, 5, shape, dtype=torch.float32, device="cuda")) - indep_param_initial = indep_param.data.clone() - indep_param.grad = torch.randint_like(indep_param, -5, 5) + test_param = nn.Parameter(torch.randint(-5, 5, shape, dtype=torch.float32, device="cuda")) + test_param.grad = torch.randint_like(test_param, -5, 5) + # With independent weight decay and lr=0, param should be exactly (1-wd)*param + expected_param = (1 - weight_decay) * test_param.data muon_opt_indep = muon.Muon( - [indep_param], + [test_param], lr=0.0, # Zero learning rate weight_decay=weight_decay, - use_independent_wd=True, + weight_decay_method="independent", momentum_beta=0.0, ) muon_opt_indep.step() - # With independent weight decay and lr=0, param should be exactly (1-wd)*param - expected_param = (1 - weight_decay) * indep_param_initial torch.testing.assert_close( - indep_param.data, + test_param, expected_param, atol=0, rtol=0, ) +class ScionTest(parameterized.TestCase): + @parameterized.parameters( + {"shape": (5, 7)}, + {"shape": (33, 65)}, + {"shape": (127, 257)}, + ) + def test_smoke(self, shape) -> None: + test_param = nn.Parameter(torch.randint(-5, 5, shape, dtype=torch.float32, device="cuda")) + test_param.grad = torch.randint_like(test_param, -5, 5) + + scion_opt = scion.Scion([test_param]) + scion_opt.step() + + if __name__ == "__main__": absltest.main() diff --git a/tests/test_soap.py b/tests/test_soap.py index 8445f2d..83b61fa 100644 --- a/tests/test_soap.py +++ b/tests/test_soap.py @@ -392,7 +392,6 @@ def test_update_matches_reference( test_optimizer = soap.SOAP( [param_test], **common_kwargs, - use_decoupled_wd=True, adam_warmup_steps=0, fp32_matmul_prec="highest", qr_fp32_matmul_prec="highest", @@ -448,7 +447,7 @@ def test_eigenbasis_matches_reference(self, shape: tuple, num_steps: int, precon test_optimizer = soap.SOAP( [param_soap], **common_kwargs, - use_decoupled_wd=False, + weight_decay_method="l2", adam_warmup_steps=0, fp32_matmul_prec="highest", qr_fp32_matmul_prec="highest",
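Below is a minimal usage sketch (not part of the patch) showing how the consolidated `weight_decay_method` keyword that replaces `use_decoupled_wd`/`use_independent_wd` is exercised. Module paths and argument names follow this diff; the toy model, learning rate, and weight decay values are illustrative assumptions only.

```python
# Sketch only: exercises the new keyword-only `weight_decay_method` option from
# WeightDecayMixin via Muon. Argument names match the updated signatures in this
# diff; the toy model and hyperparameter values are illustrative assumptions.
import torch
import torch.nn as nn

from emerging_optimizers.orthogonalized_optimizers import muon

for method in ("decoupled", "independent", "l2"):
    # "decoupled":   p <- p - lr * wd * p  (applied directly to the parameter)
    # "independent": p <- p - wd * p       (not tied to the learning rate)
    # "l2":          grad <- grad + wd * p (added before orthogonalization)
    model = nn.Linear(64, 64, bias=False)
    opt = muon.Muon(
        model.parameters(),
        lr=3e-4,
        weight_decay=0.01,
        weight_decay_method=method,
    )

    loss = model(torch.randn(8, 64)).square().mean()
    loss.backward()
    opt.step()
    opt.zero_grad()
```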