52 changes: 52 additions & 0 deletions emerging_optimizers/mixin.py
@@ -0,0 +1,52 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Literal

import torch


WeightDecayT = Literal["decoupled", "independent", "l2"]


class WeightDecayMixin:
Contributor:

What is the reason for this to be a class rather than a set of functions that are chosen based on arguments?

Contributor (Author):

Weight decay is a function highly coupled with the optimizer and shared among many of the optimizer subclasses.

I thought about an optimizer base class with this function and making every optimizer inherit from it, but not all of our optimizers use the same base.

"""Mixin for weight decay
Supports different types of weight decay:
- "decoupled": weight decay is applied directly to params without changing gradients
- "independent": similar as decoupled weight decay, but without tying weight decay and learning rate
- "l2": classic L2 regularization
"""

def _apply_weight_decay_inplace(
self,
p: torch.Tensor,
grad: torch.Tensor,
lr: float,
weight_decay: float,
) -> None:
"""Depends on the weight decay option, p or grad will be updated in place"""
if weight_decay == 0.0:
return

weight_decay_method = getattr(self, "weight_decay_method", "l2")
if weight_decay_method == "decoupled":
p.add_(p, alpha=(-weight_decay * lr))
elif weight_decay_method == "independent":
p.add_(p, alpha=-weight_decay)
elif weight_decay_method == "l2":
grad.add_(p, alpha=weight_decay)
else:
raise ValueError(f"Invalid weight decay method: {weight_decay_method}")
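For context, a minimal sketch of how this mixin is meant to compose with an arbitrary optimizer base; the `ToyOptimizer` class and its hyperparameters are hypothetical and not part of this PR. Because `_apply_weight_decay_inplace` only reads `weight_decay_method` off the instance, the mixin works with any `torch.optim.Optimizer` subclass regardless of its base, which is the point raised in the discussion above.

```python
import torch

from emerging_optimizers.mixin import WeightDecayMixin


class ToyOptimizer(WeightDecayMixin, torch.optim.Optimizer):
    """Hypothetical SGD-like optimizer illustrating how the mixin is consumed."""

    def __init__(self, params, lr=1e-3, weight_decay=0.01, weight_decay_method="decoupled"):
        # The mixin looks up `weight_decay_method` on the instance via getattr.
        self.weight_decay_method = weight_decay_method
        super().__init__(params, dict(lr=lr, weight_decay=weight_decay))

    @torch.no_grad()
    def step(self, closure=None):
        for group in self.param_groups:
            for p in group["params"]:
                if p.grad is None:
                    continue
                # Depending on the configured method, either p or p.grad is mutated in place.
                self._apply_weight_decay_inplace(p, p.grad, group["lr"], group["weight_decay"])
                p.add_(p.grad, alpha=-group["lr"])
```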
27 changes: 16 additions & 11 deletions emerging_optimizers/orthogonalized_optimizers/muon.py
@@ -18,7 +18,8 @@
from torch.optim.optimizer import ParamsT

from emerging_optimizers import triton_kernels
from emerging_optimizers.orthogonalized_optimizers.muon_utils import newton_schulz
from emerging_optimizers.mixin import WeightDecayT
from emerging_optimizers.orthogonalized_optimizers import muon_utils
from emerging_optimizers.orthogonalized_optimizers.orthogonalized_optimizer import OrthogonalizedOptimizer, _args_doc


@@ -64,10 +65,10 @@ def __init__(
params: ParamsT,
lr: float = 3e-4,
momentum_beta: float = 0.95,
use_nesterov: bool = False,
weight_decay: float = 0.01,
use_decoupled_wd: bool = True,
use_independent_wd: bool = False,
*,
use_nesterov: bool = False,
weight_decay_method: WeightDecayT = "decoupled",
fp32_matmul_prec: str = "medium",
coefficient_type: str = "quintic",
num_ns_steps: int = 5,
@@ -97,20 +98,24 @@ def scaled_orthogonalize_fn(grad: torch.Tensor) -> torch.Tensor:
f"Orthogonalizing grad with {num_ns_steps} steps, {coefficient_type} coefficient, "
f"{scale_mode} scale mode, extra_scale_factor={extra_scale_factor}"
)
orth_grad = newton_schulz(grad, steps=num_ns_steps, coefficient_type=coefficient_type, use_syrk=use_syrk)
orth_grad = muon_utils.newton_schulz(
grad,
steps=num_ns_steps,
coefficient_type=coefficient_type,
use_syrk=use_syrk,
)
scale_factor = get_muon_scale_factor(grad.size(-2), grad.size(-1), mode=scale_mode)
return orth_grad * scale_factor * extra_scale_factor

super().__init__(
params,
lr,
momentum_beta,
use_nesterov,
weight_decay,
use_decoupled_wd,
use_independent_wd,
fp32_matmul_prec,
scaled_orthogonalize_fn,
use_nesterov=use_nesterov,
weight_decay=weight_decay,
weight_decay_method=weight_decay_method,
fp32_matmul_prec=fp32_matmul_prec,
scaled_orthogonalize_fn=scaled_orthogonalize_fn,
)
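As a usage sketch of the rearranged signature: everything after the bare `*` is now keyword-only, and the former `use_decoupled_wd`/`use_independent_wd` booleans collapse into a single `weight_decay_method`. The class name `Muon` and the import path are assumed from the module name and are not confirmed by this diff.

```python
import torch

from emerging_optimizers.orthogonalized_optimizers.muon import Muon  # assumed class/module name

model = torch.nn.Linear(256, 256, bias=False)
opt = Muon(
    model.parameters(),
    lr=3e-4,
    momentum_beta=0.95,
    weight_decay=0.01,
    # Previously expressed as use_decoupled_wd=True, use_independent_wd=True.
    weight_decay_method="independent",
    use_nesterov=True,
    num_ns_steps=5,
)
```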


emerging_optimizers/orthogonalized_optimizers/orthogonalized_optimizer.py
@@ -26,23 +26,23 @@
from absl import logging
from torch.optim.optimizer import ParamsT

from emerging_optimizers import mixin as opt_mixin
from emerging_optimizers import utils


_args_doc = """params: Iterable of parameters to optimize or dicts defining parameter groups
lr: The learning rate used by the internal SGD.
momentum_beta: The momentum used by the internal SGD.
use_nesterov: Whether to use Nesterov-style momentum in the internal SGD.
weight_decay: The weight decay used by the optimizer, default to be decoupled weight decay.
See Decoupled Weight Decay Regularization: https://arxiv.org/abs/1711.05101
use_decoupled_wd: Whether to use decoupled weight decay, default to be True.
use_independent_wd: Whether to use independent weight decay (https://arxiv.org/abs/2510.19093),
default to be False.
use_nesterov: Whether to use Nesterov-style momentum in the internal SGD.
weight_decay_method: Method to apply weight decay, see :class:`~emerging_optimizers.mixin.WeightDecayMixin`
for more details.
fp32_matmul_prec: Precision of the matmul operations in optimizer states GEMM operations.
"""


class OrthogonalizedOptimizer(optim.Optimizer):
class OrthogonalizedOptimizer(opt_mixin.WeightDecayMixin, optim.Optimizer):
"""Base class for orthogonalized optimizers.

This class is a wrapper around a base optimizer that performs orthogonalization on the updates.
@@ -99,10 +99,10 @@ def __init__(
params: ParamsT,
lr: float,
momentum_beta: float,
use_nesterov: bool,
weight_decay: float,
use_decoupled_wd: bool,
use_independent_wd: bool,
*,
use_nesterov: bool,
weight_decay_method: opt_mixin.WeightDecayT,
fp32_matmul_prec: str,
scaled_orthogonalize_fn: Callable | None = None,
**kwargs: Any,
@@ -113,8 +113,7 @@ def __init__(

self.fp32_matmul_prec = fp32_matmul_prec
self.use_nesterov = use_nesterov
self.use_decoupled_wd = use_decoupled_wd
self.use_independent_wd = use_independent_wd
self.weight_decay_method = weight_decay_method

default_args_dict = dict(
lr=lr,
@@ -155,19 +154,12 @@ def step(self, closure: Callable[[], float] | None = None) -> float | None:
# Subsequent update to exp_avg are all inplace, so it is not assigned back to state.
exp_avg = state["momentum_buffer"]

# Apply weight decay
if group["weight_decay"] > 0.0:
if self.use_decoupled_wd:
# Apply weight decay directly to params without changing gradients
if self.use_independent_wd:
# do not tie weight decay and learning rate
weight_decay_scale = group["weight_decay"]
else:
weight_decay_scale = group["weight_decay"] * group["lr"]
p.add_(p, alpha=(-weight_decay_scale))
else:
# add l2 regularization before preconditioning (i.e. adding a squared loss term)
grad += group["weight_decay"] * p
self._apply_weight_decay_inplace(
p,
grad,
group["lr"],
group["weight_decay"],
)

# update momentum buffer with EMA of gradient
exp_avg.lerp_(grad, 1 - group["momentum_beta"])
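To make the refactor above concrete, here is a standalone sketch (not part of the PR) checking that the mixin's three methods reproduce the inline branches being removed; the intended mapping is decoupled ↔ `use_decoupled_wd=True`, independent ↔ `use_decoupled_wd=True, use_independent_wd=True`, and l2 ↔ `use_decoupled_wd=False`.

```python
import torch


def old_inline_wd(p, grad, lr, wd, use_decoupled_wd, use_independent_wd):
    # The branching removed by this diff, reproduced for comparison.
    if wd > 0.0:
        if use_decoupled_wd:
            scale = wd if use_independent_wd else wd * lr
            p.add_(p, alpha=-scale)
        else:
            grad.add_(p, alpha=wd)


def new_mixin_wd(p, grad, lr, wd, method):
    # Same update rules as WeightDecayMixin._apply_weight_decay_inplace.
    if wd == 0.0:
        return
    if method == "decoupled":
        p.add_(p, alpha=-wd * lr)
    elif method == "independent":
        p.add_(p, alpha=-wd)
    else:  # "l2"
        grad.add_(p, alpha=wd)


for method, (decoupled, independent) in [
    ("decoupled", (True, False)),
    ("independent", (True, True)),
    ("l2", (False, False)),
]:
    p_old, g_old = torch.randn(4), torch.randn(4)
    p_new, g_new = p_old.clone(), g_old.clone()
    old_inline_wd(p_old, g_old, lr=1e-3, wd=0.1, use_decoupled_wd=decoupled, use_independent_wd=independent)
    new_mixin_wd(p_new, g_new, lr=1e-3, wd=0.1, method=method)
    assert torch.allclose(p_old, p_new) and torch.allclose(g_old, g_new)
```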
21 changes: 9 additions & 12 deletions emerging_optimizers/orthogonalized_optimizers/scion.py
@@ -60,6 +60,7 @@ def __init__(
params: ParamsT,
lr: float = 3e-4,
momentum_beta: float = 0.95,
*,
fp32_matmul_prec: str = "medium",
coefficient_type: str = "quintic",
num_ns_steps: int = 5,
@@ -69,14 +70,11 @@
raise ValueError(f"num_ns_steps must be at least 1, got {num_ns_steps}")

# Add checks for weight decay arguments to enable Frank-Wolfe step.
logging.info("Scion does not use weight decay. Setting weight_decay to 1.")
logging.info(
"Scion does not use weight decay. Setting weight_decay to 1 and weight_decay_method to decoupled."
)
weight_decay = 1

logging.info("Scion does not use weight decay. Setting use_decoupled_wd to True to allow Franke-Wolfe.")
use_decoupled_wd = True

logging.info("Scion does not use weight decay. Setting use_independent_wd to False to allow Franke-Wolfe.")
use_independent_wd = False
weight_decay_method = "decoupled"

logging.info("Scion does not use Nesterov momentum. Setting use_nesterov to False.")
use_nesterov = False
@@ -93,10 +91,9 @@ def scaled_orthogonalize_fn(grad: torch.Tensor) -> torch.Tensor:
params,
lr,
momentum_beta,
use_nesterov,
weight_decay,
use_decoupled_wd,
use_independent_wd,
fp32_matmul_prec,
scaled_orthogonalize_fn,
use_nesterov=use_nesterov,
weight_decay_method=weight_decay_method, # type: ignore[arg-type]
fp32_matmul_prec=fp32_matmul_prec,
scaled_orthogonalize_fn=scaled_orthogonalize_fn,
)
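A note on why `weight_decay=1` with decoupled decay gives the conditional-gradient (Frank-Wolfe) behavior the comments refer to: assuming the base class first applies the decoupled decay and then subtracts the lr-scaled orthogonalized update (that final step is not shown in this diff), the result is a convex combination of the current iterate and the update direction. A tiny sketch of the algebra:

```python
import torch

lr = 0.1
p = torch.randn(8)
update = torch.randn(8)  # stands in for the orthogonalized momentum

# Decoupled decay with weight_decay=1: p <- p - lr * p, then p <- p - lr * update.
stepped = (p - lr * p) - lr * update

# Equivalently a convex combination of the iterate and the negative update direction,
# i.e. a Frank-Wolfe / conditional-gradient style step.
convex_comb = (1 - lr) * p + lr * (-update)

assert torch.allclose(stepped, convex_comb)
```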
39 changes: 19 additions & 20 deletions emerging_optimizers/psgd/psgd.py
@@ -18,9 +18,9 @@
import torch
from torch.optim.optimizer import ParamsT

from emerging_optimizers import mixin as opt_mixin
from emerging_optimizers.psgd import psgd_kron_contractions, psgd_utils
from emerging_optimizers.psgd.procrustes_step import procrustes_step
from emerging_optimizers.psgd.psgd_kron_contractions import apply_preconditioner, partial_contraction
from emerging_optimizers.psgd.psgd_utils import norm_lower_bound_spd, uniformize_q_in_place
from emerging_optimizers.soap.soap import _clip_update_rms_in_place


Expand All @@ -29,7 +29,7 @@
]


class PSGDPro(torch.optim.Optimizer):
class PSGDPro(opt_mixin.WeightDecayMixin, torch.optim.Optimizer):
"""Implements a variant of the PSGD optimization algorithm (PSGD-Kron-Whiten with Procrustes step for preconditioner update).

Preconditioned Stochastic Gradient Descent (PSGD) (https://arxiv.org/abs/1512.04202) is a preconditioned optimization algorithm
@@ -42,8 +42,8 @@ class PSGDPro(torch.optim.Optimizer):
params: Iterable of parameters to optimize or dicts defining parameter groups
lr: The learning rate to use
weight_decay: Weight decay coefficient
use_decoupled_wd: Whether to use decoupled weight decay, see Decoupled Weight Decay Regularization:
https://arxiv.org/abs/1711.05101.
weight_decay_method: Method to apply weight decay, see :class:`~emerging_optimizers.mixin.WeightDecayMixin`
for more details.
momentum: Momentum coefficient for exponential moving average of gradient.
beta_lip: EMA beta for the Lipschitz constants.
precond_lr: Inner learning rate for the preconditioner.
@@ -59,8 +59,9 @@ def __init__(
params: ParamsT,
lr: float = 3e-3,
weight_decay: float = 0.01,
use_decoupled_wd: bool = True,
momentum: float = 0.9,
*,
weight_decay_method: opt_mixin.WeightDecayT = "decoupled",
beta_lip: float = 0.9,
precond_lr: float = 0.1,
precond_init_scale: float = 1.0,
@@ -69,7 +70,7 @@
warmup_steps: int = 10000,
max_update_rms: float = 0.0,
) -> None:
self.use_decoupled_wd = use_decoupled_wd
self.weight_decay_method = weight_decay_method
self.max_update_rms = max_update_rms
self.precond_init_scale = precond_init_scale
self.damping_noise_scale = damping_noise_scale
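For illustration, a construction sketch under the updated `PSGDPro` signature, where arguments from `weight_decay_method` onward are keyword-only; the toy model and values below are placeholders, not recommendations.

```python
import torch

# Import path follows the file path shown in this diff (emerging_optimizers/psgd/psgd.py).
from emerging_optimizers.psgd.psgd import PSGDPro

model = torch.nn.Linear(64, 64)
opt = PSGDPro(
    model.parameters(),
    lr=3e-3,
    weight_decay=0.01,
    momentum=0.9,
    # Replaces the former use_decoupled_wd flag; "l2" recovers the non-decoupled behavior.
    weight_decay_method="l2",
    precond_lr=0.1,
    warmup_steps=10000,
)
```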
@@ -117,14 +118,12 @@ def step(self, closure: Callable[[], float] | None = None) -> float | None:
precond_init_scale=self.precond_init_scale,
)

# weight decay
if group["weight_decay"] > 0.0:
if self.use_decoupled_wd:
# Apply decoupled weight decay
p.add_(p, alpha=(-group["lr"] * group["weight_decay"]))
else:
# add l2 regularization before preconditioning (i.e. adding a squared loss term)
grad += group["weight_decay"] * p
self._apply_weight_decay_inplace(
p,
grad,
group["lr"],
group["weight_decay"],
)

# update momentum buffer with EMA of gradient
exp_avg = state["exp_avg"]
@@ -140,10 +139,10 @@ def step(self, closure: Callable[[], float] | None = None) -> float | None:
state["Q"], state["L"] = _update_precond_procrustes(
state["Q"], state["L"], exp_avg, self.damping_noise_scale, precond_lr, beta_lip
)
uniformize_q_in_place(state["Q"])
psgd_utils.uniformize_q_in_place(state["Q"])

# Get weight update by preconditioning the momentum
update = apply_preconditioner(state["Q"], exp_avg)
update = psgd_kron_contractions.apply_preconditioner(state["Q"], exp_avg)
_clip_update_rms_in_place(update, self.max_update_rms)

# Apply weight update
@@ -200,13 +199,13 @@ def _update_precond_procrustes(
lip_const_list: List of Lipschitz constants for the Kronecker factors.
"""
dampened_momentum = exp_avg + (damping_noise_scale + 1e-7 * exp_avg.abs()) * torch.randn_like(exp_avg)
pg = apply_preconditioner(q_list, dampened_momentum)
pg = psgd_kron_contractions.apply_preconditioner(q_list, dampened_momentum)
total_numel = pg.numel()
updated_q_list: List[torch.Tensor] = []
updated_lip_const_list: List[torch.Tensor] = []
for dim, q in enumerate(q_list):
# compute gradient covariance
precond_grad_cov = partial_contraction(pg, pg, dim)
precond_grad_cov = psgd_kron_contractions.partial_contraction(pg, pg, dim)
if q.dim() < 2:
# diagonal or scalar-structured preconditioner
q, updated_lip_const = _update_1d_preconditioner(
@@ -246,7 +245,7 @@ def _update_matrix_preconditioner(
lip_const: Updated Lipschitz constant for this dimension.
"""
normalization = total_numel / q.shape[0]
ell = norm_lower_bound_spd(precond_grad_cov) + normalization
ell = psgd_utils.norm_lower_bound_spd(precond_grad_cov) + normalization
lip_const = torch.max(beta_lip * lip_const + (1 - beta_lip) * ell, ell)
q = q - precond_lr / lip_const * (precond_grad_cov @ q - normalization * q)
q = procrustes_step(q)