diff --git a/emerging_optimizers/orthogonalized_optimizers/orthogonalized_optimizer.py b/emerging_optimizers/orthogonalized_optimizers/orthogonalized_optimizer.py
index 27d10ab..4d65e62 100644
--- a/emerging_optimizers/orthogonalized_optimizers/orthogonalized_optimizer.py
+++ b/emerging_optimizers/orthogonalized_optimizers/orthogonalized_optimizer.py
@@ -112,13 +112,14 @@ def __init__(
             scaled_orthogonalize_fn = torch.nn.Identity()

         self.fp32_matmul_prec = fp32_matmul_prec
+        self.use_nesterov = use_nesterov
+        self.use_decoupled_wd = use_decoupled_wd
+        self.use_independent_wd = use_independent_wd
+
         default_args_dict = dict(
             lr=lr,
             momentum_beta=momentum_beta,
-            use_nesterov=use_nesterov,
             weight_decay=weight_decay,
-            use_decoupled_wd=use_decoupled_wd,
-            use_independent_wd=use_independent_wd,
             **kwargs,
         )

@@ -156,9 +157,9 @@ def step(self, closure: Callable[[], float] | None = None) -> float | None:

                 # Apply weight decay
                 if group["weight_decay"] > 0.0:
-                    if group["use_decoupled_wd"]:
+                    if self.use_decoupled_wd:
                         # Apply weight decay directly to params without changing gradients
-                        if group["use_independent_wd"]:
+                        if self.use_independent_wd:
                             # do not tie weight decay and learning rate
                             weight_decay_scale = group["weight_decay"]
                         else:
@@ -172,7 +173,7 @@ def step(self, closure: Callable[[], float] | None = None) -> float | None:
                     exp_avg.lerp_(grad, 1 - group["momentum_beta"])

                     # include nesterov momentum
-                    if group["use_nesterov"]:
+                    if self.use_nesterov:
                         grad = grad.lerp(exp_avg, group["momentum_beta"])
                     else:
                         grad = exp_avg
diff --git a/emerging_optimizers/psgd/psgd.py b/emerging_optimizers/psgd/psgd.py
index 34dd628..28fed2e 100644
--- a/emerging_optimizers/psgd/psgd.py
+++ b/emerging_optimizers/psgd/psgd.py
@@ -42,7 +42,7 @@ class PSGDPro(torch.optim.Optimizer):
         params: Iterable of parameters to optimize or dicts defining parameter groups
         lr: The learning rate to use
         weight_decay: Weight decay coefficient
-        use_decoupled_weight_decay: Whether to use decoupled weight decay, see Decoupled Weight Decay Regularization:
+        use_decoupled_wd: Whether to use decoupled weight decay, see Decoupled Weight Decay Regularization:
            https://arxiv.org/abs/1711.05101.
        momentum: Momentum coefficient for exponential moving average of gradient.
        beta_lip: EMA beta for the Lipschitz constants.
@@ -59,7 +59,7 @@ def __init__(
         params: ParamsT,
         lr: float = 3e-3,
         weight_decay: float = 0.01,
-        use_decoupled_weight_decay: bool = True,
+        use_decoupled_wd: bool = True,
         momentum: float = 0.9,
         beta_lip: float = 0.9,
         precond_lr: float = 0.1,
@@ -69,18 +69,18 @@
         warmup_steps: int = 10000,
         max_update_rms: float = 0.0,
     ) -> None:
+        self.use_decoupled_wd = use_decoupled_wd
+        self.max_update_rms = max_update_rms
+        self.precond_init_scale = precond_init_scale
+        self.damping_noise_scale = damping_noise_scale
+        self.warmup_steps = warmup_steps
         defaults = {
             "lr": lr,
             "beta_lip": beta_lip,
             "weight_decay": weight_decay,
-            "use_decoupled_weight_decay": use_decoupled_weight_decay,
             "momentum": momentum,
             "precond_lr": precond_lr,
-            "precond_init_scale": precond_init_scale,
-            "max_update_rms": max_update_rms,
             "min_precond_lr": min_precond_lr,
-            "warmup_steps": warmup_steps,
-            "damping_noise_scale": damping_noise_scale,
         }
         super().__init__(params, defaults)

@@ -114,12 +114,12 @@ def step(self, closure: Callable[[], float] | None = None) -> float | None:
                 if "Q" not in state or "L" not in state:
                     state["Q"], state["L"] = _init_psgd_kron_states(
                         grad,
-                        precond_init_scale=group["precond_init_scale"],
+                        precond_init_scale=self.precond_init_scale,
                     )

                 # weight decay
                 if group["weight_decay"] > 0.0:
-                    if group["use_decoupled_weight_decay"]:
+                    if self.use_decoupled_wd:
                         # Apply decoupled weight decay
                         p.add_(p, alpha=(-group["lr"] * group["weight_decay"]))
                     else:
@@ -131,21 +131,20 @@ def step(self, closure: Callable[[], float] | None = None) -> float | None:
                 exp_avg.lerp_(grad, 1 - group["momentum"])

                 # Get hyperparameters for preconditioner update
-                damping_noise_scale = group["damping_noise_scale"]
                 precond_lr = _get_precond_lr(
-                    group["precond_lr"], state["step"], group["min_precond_lr"], group["warmup_steps"]
+                    group["precond_lr"], state["step"], group["min_precond_lr"], self.warmup_steps
                 )
                 beta_lip = group["beta_lip"]

                 # Preconditioner update
                 state["Q"], state["L"] = _update_precond_procrustes(
-                    state["Q"], state["L"], exp_avg, damping_noise_scale, precond_lr, beta_lip
+                    state["Q"], state["L"], exp_avg, self.damping_noise_scale, precond_lr, beta_lip
                 )
                 uniformize_q_in_place(state["Q"])

                 # Get weight update by preconditioning the momentum
                 update = apply_preconditioner(state["Q"], exp_avg)
-                _clip_update_rms_in_place(update, group["max_update_rms"])
+                _clip_update_rms_in_place(update, self.max_update_rms)

                 # Apply weight update
                 p.add_(update, alpha=-group["lr"])
diff --git a/emerging_optimizers/soap/soap.py b/emerging_optimizers/soap/soap.py
index 0d1bd2a..5a4de6e 100644
--- a/emerging_optimizers/soap/soap.py
+++ b/emerging_optimizers/soap/soap.py
@@ -14,7 +14,7 @@
 # limitations under the License.

 from functools import partial
 from itertools import chain
-from typing import Callable, Iterable, List, Optional, Tuple, Union
+from typing import Callable, List, Optional, Tuple, Union

 # TODO(@boxiangw): remove this once bump to python 3.12
@@ -26,6 +26,7 @@
 import torch
 import torch.optim as optim
 from absl import logging
+from torch.optim.optimizer import ParamsT

 from emerging_optimizers import utils
 from emerging_optimizers.scalar_optimizers import calculate_adam_update
@@ -60,13 +61,12 @@ class SOAP(optim.Optimizer):
             instead of betas[1] if >= 0
         eps: Inner Adam's epsilon for numerical stability
         weight_decay: Weight decay coefficient
-        use_decoupled_weight_decay: Whether to use decoupled weight decay, see Decoupled Weight Decay Regularization:
+        use_decoupled_wd: Whether to use decoupled weight decay, see Decoupled Weight Decay Regularization:
             https://arxiv.org/abs/1711.05101.
         use_nesterov: uses Nesterov momentum in Adam (https://cs229.stanford.edu/proj2015/054_report.pdf,
             https://openreview.net/forum?id=OM0jvwB8jIp57ZJjtNEZ)
         precondition_frequency: How often to update the preconditioner. Can be an integer for fixed frequency or
             a callable function that takes the current step as input and returns the frequency.
-        precondition_warmup_steps: How many steps to warm up the preconditioner (i.e. update every step)
         adam_warmup_steps: How many steps to skip preconditioning in the beginning (i.e. use standard AdamW updates)
         precondition_1d: Whether to precondition 1D gradients (like biases).
         trace_normalization: Whether to normalize update by the trace of the kronecker factor matrix
@@ -87,16 +87,15 @@ def __init__(
         self,
-        params: Iterable[torch.nn.parameter.Parameter],
+        params: ParamsT,
         lr: float = 3e-3,
-        betas: Optional[Tuple[float, float]] = None,
-        shampoo_beta: float = -1,
+        betas: Tuple[float, float] = (0.95, 0.95),
+        shampoo_beta: float = 0.95,
         eps: float = 1e-8,
         weight_decay: float = 0.01,
-        use_decoupled_weight_decay: bool = True,
+        use_decoupled_wd: bool = True,
         use_nesterov: bool = False,
-        precondition_frequency: Union[int, Callable[[int], int]] = 10,
-        precondition_warmup_steps: int = 0,
+        precondition_frequency: Union[int, Callable[[int], int]] = 1,
         adam_warmup_steps: int = 1,
         precondition_1d: bool = False,
         trace_normalization: bool = False,
@@ -106,39 +105,27 @@ def __init__(
         use_eigh: bool = False,
         qr_fp32_matmul_prec: str = "high",
         use_adaptive_criteria: bool = False,
-        adaptive_update_tolerance: Optional[float] = None,
+        adaptive_update_tolerance: float = 1e-7,
         power_iter_steps: int = 1,
         max_update_rms: float = 0.0,
         use_kl_shampoo: bool = False,
     ) -> None:
-        # Check for betas.
-        if betas is None:
-            betas = (0.95, 0.95)
-            logging.debug(f"betas not provided. Setting betas equal to betas = {betas} by default.")
-
-        # Check for update criteria
-        if use_adaptive_criteria:
-            if adaptive_update_tolerance is None:
-                adaptive_update_tolerance = 1e-7
-                logging.info(
-                    "adaptive_update_tolerance not provided. Setting adaptive_update_tolerance equal to "
-                    f"eps = {adaptive_update_tolerance} by default."
-                )
-
-        # Check for adam_warmup_steps since <1 will cause key errors in update_eigenbasis_and_momentum step
-        if adam_warmup_steps < 1:
-            adam_warmup_steps = 1
-            logging.info("adam_warmup_steps is less than 1. Setting adam_warmup_steps to 1 by default.")
-
-        # Check for precondition warmup steps and adam warmup steps
-        if adam_warmup_steps >= precondition_warmup_steps and precondition_warmup_steps > 0:
-            original_adam_warmup_steps = adam_warmup_steps
-            adam_warmup_steps = max(1, precondition_warmup_steps - 1)
-            logging.info(
-                f"adam_warmup_steps ({original_adam_warmup_steps}) should be less "
-                f"than precondition_warmup_steps ({precondition_warmup_steps}). "
-                f"Setting adam_warmup_steps to {adam_warmup_steps} by default."
-            )
+        self.precondition_frequency = precondition_frequency
+        self.adam_warmup_steps = adam_warmup_steps
+        self.precondition_1d = precondition_1d
+        self.trace_normalization = trace_normalization
+        self.normalize_preconditioned_grads = normalize_preconditioned_grads
+        self.use_nesterov = use_nesterov
+        self.correct_bias = correct_bias
+        self.use_decoupled_wd = use_decoupled_wd
+        self.fp32_matmul_prec = fp32_matmul_prec
+        self.use_eigh = use_eigh
+        self.qr_fp32_matmul_prec = qr_fp32_matmul_prec
+        self.use_adaptive_criteria = use_adaptive_criteria
+        self.adaptive_update_tolerance = adaptive_update_tolerance
+        self.power_iter_steps = power_iter_steps
+        self.max_update_rms = max_update_rms
+        self.use_kl_shampoo = use_kl_shampoo

         defaults = {
             "lr": lr,
@@ -146,23 +133,6 @@ def __init__(
             "shampoo_beta": shampoo_beta,
             "eps": eps,
             "weight_decay": weight_decay,
-            "precondition_frequency": precondition_frequency,
-            "precondition_warmup_steps": precondition_warmup_steps,
-            "adam_warmup_steps": adam_warmup_steps,
-            "precondition_1d": precondition_1d,
-            "trace_normalization": trace_normalization,
-            "normalize_preconditioned_grads": normalize_preconditioned_grads,
-            "use_nesterov": use_nesterov,
-            "correct_bias": correct_bias,
-            "use_decoupled_weight_decay": use_decoupled_weight_decay,
-            "fp32_matmul_prec": fp32_matmul_prec,
-            "use_eigh": use_eigh,
-            "qr_fp32_matmul_prec": qr_fp32_matmul_prec,
-            "use_adaptive_criteria": use_adaptive_criteria,
-            "adaptive_update_tolerance": adaptive_update_tolerance,
-            "power_iter_steps": power_iter_steps,
-            "max_update_rms": max_update_rms,
-            "use_kl_shampoo": use_kl_shampoo,
         }
         super().__init__(params, defaults)

@@ -203,10 +173,8 @@ def step(self, closure: Callable[[], float] | None = None) -> float | None:

                 # Define kronecker_factor_update_fn based on whether to use KL-Shampoo here
                 # because it needs access to state and group
-                kronecker_factor_update_fn = partial(
-                    update_kronecker_factors, precondition_1d=group["precondition_1d"]
-                )
-                if group["use_kl_shampoo"]:
+                kronecker_factor_update_fn = partial(update_kronecker_factors, precondition_1d=self.precondition_1d)
+                if self.use_kl_shampoo:
                     kronecker_factor_update_fn = partial(
                         update_kronecker_factors_kl_shampoo,
                         eigenbasis_list=state["Q"],
@@ -217,12 +185,12 @@ def step(self, closure: Callable[[], float] | None = None) -> float | None:
                 if "GG" not in state:
                     state["GG"] = init_kronecker_factors(
                         grad,
-                        precondition_1d=group["precondition_1d"],
+                        precondition_1d=self.precondition_1d,
                     )

                     # Update preconditioner matrices with gradient statistics,
                     # do not use shampoo_beta for EMA at first step
-                    with utils.fp32_matmul_precision(group["fp32_matmul_prec"]):
+                    with utils.fp32_matmul_precision(self.fp32_matmul_prec):
                         kronecker_factor_update_fn(
                             kronecker_factor_list=state["GG"], grad=grad, shampoo_beta=group["shampoo_beta"]
                         )
@@ -232,7 +200,7 @@ def step(self, closure: Callable[[], float] | None = None) -> float | None:

                 # Apply weight decay
                 if group["weight_decay"] > 0.0:
-                    if group["use_decoupled_weight_decay"]:
+                    if self.use_decoupled_wd:
                         # Apply decoupled weight decay
                         p.add_(p, alpha=(-group["lr"] * group["weight_decay"]))
                     else:
@@ -241,7 +209,7 @@ def step(self, closure: Callable[[], float] | None = None) -> float | None:

                 # Projecting gradients to the eigenbases of Shampoo's preconditioner
                 torch.cuda.nvtx.range_push("precondition")
-                with utils.fp32_matmul_precision(group["fp32_matmul_prec"]):
+                with utils.fp32_matmul_precision(self.fp32_matmul_prec):
                     grad_projected = precondition(
                         grad=grad,
                         eigenbasis_list=state["Q"],
@@ -258,8 +226,8 @@ def step(self, closure: Callable[[], float] | None = None) -> float | None:
                     exp_avg,
                     exp_avg_sq,
                     group["betas"],
-                    group["correct_bias"],
-                    group["use_nesterov"],
+                    self.correct_bias,
+                    self.use_nesterov,
                     state["step"],
                     group["eps"],
                 )
@@ -268,7 +236,7 @@ def step(self, closure: Callable[[], float] | None = None) -> float | None:

                 # Projecting back the preconditioned (by ADAM) exponential moving average of gradients
                 torch.cuda.nvtx.range_push("precondition")
-                with utils.fp32_matmul_precision(group["fp32_matmul_prec"]):
+                with utils.fp32_matmul_precision(self.fp32_matmul_prec):
                     norm_precond_grad = precondition(
                         grad=adam_update,
                         eigenbasis_list=state["Q"],
@@ -276,16 +244,16 @@ def step(self, closure: Callable[[], float] | None = None) -> float | None:
                     )
                 torch.cuda.nvtx.range_pop()

-                if group["trace_normalization"]:
+                if self.trace_normalization:
                     if state["GG"][0].numel() > 0:
                         trace_normalization = 1 / torch.sqrt(torch.trace(state["GG"][0]))
                         norm_precond_grad = norm_precond_grad / trace_normalization

-                if group["normalize_preconditioned_grads"]:
+                if self.normalize_preconditioned_grads:
                     norm_precond_grad = norm_precond_grad / (1e-30 + torch.mean(norm_precond_grad**2) ** 0.5)

                 # Clip the update RMS to a maximum value
-                _clip_update_rms_in_place(norm_precond_grad, group["max_update_rms"])
+                _clip_update_rms_in_place(norm_precond_grad, self.max_update_rms)

                 torch.cuda.nvtx.range_push("weight update")
                 p.add_(norm_precond_grad, alpha=-step_size)
@@ -293,12 +261,12 @@ def step(self, closure: Callable[[], float] | None = None) -> float | None:

                 # Update kronecker factor matrices with gradient statistics
                 shampoo_beta = group["shampoo_beta"] if group["shampoo_beta"] >= 0 else group["betas"][1]
-                if group["correct_bias"]:
+                if self.correct_bias:
                     # step size correction for shampoo kronecker factors EMA
                     shampoo_beta = 1 - (1 - shampoo_beta) / (1 - shampoo_beta ** (state["step"] + 1))

                 torch.cuda.nvtx.range_push("update_kronecker_factors")
-                with utils.fp32_matmul_precision(group["fp32_matmul_prec"]):
+                with utils.fp32_matmul_precision(self.fp32_matmul_prec):
                     kronecker_factor_update_fn(
                         kronecker_factor_list=state["GG"],
                         grad=grad,
@@ -308,11 +276,11 @@ def step(self, closure: Callable[[], float] | None = None) -> float | None:

                 # If current step is the last step to skip preconditioning, initialize eigenbases and
                 # end first order warmup
-                if state["step"] == group["adam_warmup_steps"]:
+                if state["step"] == self.adam_warmup_steps:
                     # Obtain kronecker factor eigenbases from kronecker factor matrices using eigendecomposition
                     state["Q"] = get_eigenbasis_eigh(state["GG"])
                     # rotate momentum to the new eigenbasis
-                    with utils.fp32_matmul_precision(group["fp32_matmul_prec"]):
+                    with utils.fp32_matmul_precision(self.fp32_matmul_prec):
                         state["exp_avg"] = precondition(
                             grad=state["exp_avg"],
                             eigenbasis_list=state["Q"],
@@ -320,25 +288,24 @@ def step(self, closure: Callable[[], float] | None = None) -> float | None:
                         )
                     continue

-                # Update eigenbases at precondition_frequency steps or until precondition_warmup_steps is done,
-                # but only after the adam_warmup_steps are completed.
+                # After the adam_warmup_steps are completed.
+                # Update eigenbases at precondition_frequency steps
                 torch.cuda.nvtx.range_push("Update eigen basis")
                 if _is_eigenbasis_update_step(
                     state["step"],
-                    group["adam_warmup_steps"],
-                    group["precondition_warmup_steps"],
-                    group["precondition_frequency"],
+                    self.adam_warmup_steps,
+                    self.precondition_frequency,
                 ):
-                    with utils.fp32_matmul_precision(group["qr_fp32_matmul_prec"]):
+                    with utils.fp32_matmul_precision(self.qr_fp32_matmul_prec):
                         state["Q"], state["exp_avg"], state["exp_avg_sq"] = update_eigenbasis_and_momentum(
                             kronecker_factor_list=state["GG"],
                             eigenbasis_list=state["Q"],
                             exp_avg_sq=state["exp_avg_sq"],
                             momentum=state["exp_avg"],
-                            use_eigh=group["use_eigh"],
-                            use_adaptive_criteria=group["use_adaptive_criteria"],
-                            adaptive_update_tolerance=group["adaptive_update_tolerance"],
-                            power_iter_steps=group["power_iter_steps"],
+                            use_eigh=self.use_eigh,
+                            use_adaptive_criteria=self.use_adaptive_criteria,
+                            adaptive_update_tolerance=self.adaptive_update_tolerance,
+                            power_iter_steps=self.power_iter_steps,
                         )
                 torch.cuda.nvtx.range_pop()

@@ -662,26 +629,9 @@ def precondition(
     return grad


-def _get_precondition_frequency(precondition_frequency: Union[int, Callable[[int], int]], step: int) -> int:
-    """Get the current precondition frequency based on the schedule or fixed value.
-
-    Args:
-        precondition_frequency: Either an integer for fixed frequency or a callable that takes step and returns frequency
-        step: Current optimization step
-
-    Returns:
-        The precondition frequency for the current step
-    """
-    if callable(precondition_frequency):
-        return precondition_frequency(step)
-    else:
-        return precondition_frequency
-
-
 def _is_eigenbasis_update_step(
     step: int,
     adam_warmup_steps: int,
-    precondition_warmup_steps: int,
     precondition_frequency: Union[int, Callable[[int], int]],
 ) -> bool:
     """Checks if amortized computation of the eigenbasis should be recomputed.
@@ -689,19 +639,16 @@

     Args:
         step: Current step of the optimizer
         adam_warmup_steps: Number of steps to skip preconditioning in the beginning (i.e. use standard AdamW updates)
-        precondition_warmup_steps: How many steps to warm up the preconditioner (i.e. update every step)
        precondition_frequency: How often to update the preconditioner. Can be an integer for fixed frequency or
            a callable function that takes the current step as input and returns the frequency.
""" - if step <= adam_warmup_steps: + if step < adam_warmup_steps: return False - # During warmup period, update every step - if step <= precondition_warmup_steps: - return True + current_frequency = ( + precondition_frequency if not callable(precondition_frequency) else precondition_frequency(step) + ) - # After warmup, use the scheduled frequency - current_frequency = _get_precondition_frequency(precondition_frequency, step) return step % current_frequency == 0 diff --git a/tests/test_soap.py b/tests/test_soap.py index 51cd11c..f7864d0 100644 --- a/tests/test_soap.py +++ b/tests/test_soap.py @@ -22,7 +22,6 @@ from emerging_optimizers.soap import soap from emerging_optimizers.soap.soap import ( _clip_update_rms_in_place, - _get_precondition_frequency, _is_eigenbasis_update_step, ) from emerging_optimizers.utils.precondition_schedules import LinearSchedule @@ -216,52 +215,46 @@ def test_project_and_project_back(self, N: int, M: int) -> None: msg="Project and project_back did not recover the original tensor.", ) - def test_get_precondition_frequency_fixed(self) -> None: - """Test that _get_precondition_frequency works with fixed frequency (default case).""" - freq = _get_precondition_frequency(10, 100) - self.assertEqual(freq, 10) - @parameterized.parameters( - (5, 10, 20, 10, False), - (15, 10, 20, 10, True), - (20, 10, 15, 10, True), - (21, 10, 15, 10, False), - (30, 10, 15, 10, True), - (31, 10, 15, 10, False), + (5, 10, 10, False), + (15, 10, 5, True), + (20, 10, 10, True), + (21, 10, 10, False), + (30, 10, 10, True), + (31, 10, 10, False), ) def test_is_eigenbasis_update_step_fixed_frequency( - self, step: int, adam_warmup_steps: int, precondition_warmup: int, precondition_frequency: int, expected: bool + self, step: int, adam_warmup_steps: int, precondition_frequency: int, expected: bool ) -> None: """Test _is_eigenbasis_update_step with fixed frequency.""" - result = _is_eigenbasis_update_step(step, adam_warmup_steps, precondition_warmup, precondition_frequency) + result = _is_eigenbasis_update_step(step, adam_warmup_steps, precondition_frequency) self.assertEqual(result, expected) def test_soap_optimizer_fixed_frequency(self) -> None: """Test that SOAP optimizer can be created with fixed precondition frequency (default case).""" param = torch.randn(10, 5, requires_grad=True) optimizer = soap.SOAP([param], lr=1e-3, precondition_frequency=10) - self.assertEqual(optimizer.param_groups[0]["precondition_frequency"], 10) + self.assertEqual(optimizer.precondition_frequency, 10) def test_soap_optimizer_class_based_schedule(self) -> None: """Test that SOAP optimizer can be created with class-based precondition frequency schedule.""" param = torch.randn(10, 5, requires_grad=True) schedule = LinearSchedule(min_freq=2, max_freq=10, transition_steps=100) optimizer = soap.SOAP([param], lr=1e-3, precondition_frequency=schedule) - self.assertTrue((optimizer.param_groups[0]["precondition_frequency"]) == schedule) + self.assertTrue(optimizer.precondition_frequency == schedule) self.assertEqual(schedule(0), 2) self.assertEqual(schedule(50), 6) self.assertEqual(schedule(100), 10) adam_warmup = 1 - precondition_warmup = 0 - - self.assertTrue(_is_eigenbasis_update_step(10, adam_warmup, precondition_warmup, schedule)) - self.assertFalse(_is_eigenbasis_update_step(11, adam_warmup, precondition_warmup, schedule)) - self.assertTrue(_is_eigenbasis_update_step(60, adam_warmup, precondition_warmup, schedule)) - self.assertFalse(_is_eigenbasis_update_step(61, adam_warmup, precondition_warmup, schedule)) - 
self.assertTrue(_is_eigenbasis_update_step(120, adam_warmup, precondition_warmup, schedule)) - self.assertFalse(_is_eigenbasis_update_step(121, adam_warmup, precondition_warmup, schedule)) + + self.assertTrue(_is_eigenbasis_update_step(10, adam_warmup, schedule)) + self.assertFalse(_is_eigenbasis_update_step(11, adam_warmup, schedule)) + self.assertTrue(_is_eigenbasis_update_step(60, adam_warmup, schedule)) + self.assertFalse(_is_eigenbasis_update_step(61, adam_warmup, schedule)) + self.assertTrue(_is_eigenbasis_update_step(120, adam_warmup, schedule)) + self.assertFalse(_is_eigenbasis_update_step(121, adam_warmup, schedule)) @parameterized.parameters( (1.0,),
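For readers skimming the diff, the following minimal standalone sketch (not taken from the repository) illustrates the two patterns the changes above apply: static, non-schedulable settings such as use_decoupled_wd or the precondition frequency become plain attributes on the optimizer instance, while only per-group, schedulable hyperparameters stay in the defaults dict; and the amortized-update check skips steps before a warmup threshold and then honors a fixed or callable frequency. SketchOptimizer and is_update_step are hypothetical names used for illustration only, not part of the library.

from typing import Callable, Union

import torch


def is_update_step(step: int, warmup_steps: int, frequency: Union[int, Callable[[int], int]]) -> bool:
    # Skip the warmup phase entirely, then honor the (possibly step-dependent) frequency.
    if step < warmup_steps:
        return False
    current_frequency = frequency if not callable(frequency) else frequency(step)
    return step % current_frequency == 0


class SketchOptimizer(torch.optim.Optimizer):
    def __init__(
        self,
        params,
        lr: float = 1e-3,
        weight_decay: float = 0.0,
        use_decoupled_wd: bool = True,
        warmup_steps: int = 1,
        frequency: Union[int, Callable[[int], int]] = 1,
    ) -> None:
        # Static, non-schedulable settings become instance attributes instead of per-group entries.
        self.use_decoupled_wd = use_decoupled_wd
        self.warmup_steps = warmup_steps
        self.frequency = frequency
        # Only values a user might reasonably vary per parameter group remain in defaults.
        super().__init__(params, {"lr": lr, "weight_decay": weight_decay})

    @torch.no_grad()
    def step(self, closure: Callable[[], float] | None = None) -> float | None:
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()
        for group in self.param_groups:
            for p in group["params"]:
                if p.grad is None:
                    continue
                state = self.state[p]
                state["step"] = state.get("step", 0) + 1
                if group["weight_decay"] > 0.0 and self.use_decoupled_wd:
                    # Decoupled weight decay applied directly to the parameters.
                    p.add_(p, alpha=-group["lr"] * group["weight_decay"])
                if is_update_step(state["step"], self.warmup_steps, self.frequency):
                    pass  # an expensive amortized update (e.g. refreshing a preconditioner) would go here
                # Plain gradient descent step stands in for the real preconditioned update.
                p.add_(p.grad, alpha=-group["lr"])
        return loss


# Example: a callable frequency is handled the same way as a fixed integer one.
opt = SketchOptimizer([torch.nn.Parameter(torch.randn(4, 4))], lr=1e-2, warmup_steps=2, frequency=lambda s: 5)

One practical upside of this split is that param_groups, which PyTorch serializes in the optimizer state_dict, stay limited to plain scalar hyperparameters rather than booleans and callable schedules.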