@@ -112,13 +112,14 @@ def __init__(
scaled_orthogonalize_fn = torch.nn.Identity()

self.fp32_matmul_prec = fp32_matmul_prec
+ self.use_nesterov = use_nesterov
+ self.use_decoupled_wd = use_decoupled_wd
+ self.use_independent_wd = use_independent_wd

default_args_dict = dict(
lr=lr,
momentum_beta=momentum_beta,
- use_nesterov=use_nesterov,
weight_decay=weight_decay,
- use_decoupled_wd=use_decoupled_wd,
- use_independent_wd=use_independent_wd,
**kwargs,
)

@@ -156,9 +157,9 @@ def step(self, closure: Callable[[], float] | None = None) -> float | None:

# Apply weight decay
if group["weight_decay"] > 0.0:
if group["use_decoupled_wd"]:
if self.use_decoupled_wd:
# Apply weight decay directly to params without changing gradients
if group["use_independent_wd"]:
if self.use_independent_wd:
# do not tie weight decay and learning rate
weight_decay_scale = group["weight_decay"]
else:
@@ -172,7 +173,7 @@ def step(self, closure: Callable[[], float] | None = None) -> float | None:
exp_avg.lerp_(grad, 1 - group["momentum_beta"])

# include nesterov momentum
if group["use_nesterov"]:
if self.use_nesterov:
grad = grad.lerp(exp_avg, group["momentum_beta"])
else:
grad = exp_avg
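Taken together, these hunks move the momentum and weight-decay flags from per-group entries onto the optimizer instance. As a reading aid, here is a hedged, standalone sketch of the update path they touch: the helper name `apply_step`, the bare-tensor arguments, and the coupled-weight-decay branch are illustrative assumptions, and the orthogonalization the real optimizer applies to the update is omitted.

```python
import torch


def apply_step(p, grad, exp_avg, *, lr, weight_decay, momentum_beta,
               use_decoupled_wd, use_independent_wd, use_nesterov):
    """Sketch of the weight-decay and momentum logic shown in the diff above."""
    if weight_decay > 0.0:
        if use_decoupled_wd:
            # Decoupled: shrink the weights directly without touching the gradient.
            if use_independent_wd:
                # Independent: decay strength is not tied to the learning rate.
                weight_decay_scale = weight_decay
            else:
                weight_decay_scale = lr * weight_decay
            p.add_(p, alpha=-weight_decay_scale)
        else:
            # Coupled (L2-style): fold the decay term into the gradient (assumed branch,
            # not shown in the visible hunks).
            grad = grad.add(p, alpha=weight_decay)

    # Exponential moving average of the gradient.
    exp_avg.lerp_(grad, 1 - momentum_beta)

    # Nesterov-style momentum blends the fresh gradient back into the EMA direction.
    update = grad.lerp(exp_avg, momentum_beta) if use_nesterov else exp_avg

    # The real optimizer orthogonalizes/scales the update before applying it.
    p.add_(update, alpha=-lr)
```

For one parameter, something like `apply_step(p, p.grad, state["exp_avg"], lr=1e-3, weight_decay=0.1, momentum_beta=0.95, use_decoupled_wd=True, use_independent_wd=False, use_nesterov=True)` mirrors a step with the new instance-level flags.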
emerging_optimizers/psgd/psgd.py (25 changes: 12 additions & 13 deletions)
@@ -42,7 +42,7 @@ class PSGDPro(torch.optim.Optimizer):
params: Iterable of parameters to optimize or dicts defining parameter groups
lr: The learning rate to use
weight_decay: Weight decay coefficient
- use_decoupled_weight_decay: Whether to use decoupled weight decay, see Decoupled Weight Decay Regularization:
+ use_decoupled_wd: Whether to use decoupled weight decay, see Decoupled Weight Decay Regularization:
https://arxiv.org/abs/1711.05101.
momentum: Momentum coefficient for exponential moving average of gradient.
beta_lip: EMA beta for the Lipschitz constants.
@@ -59,7 +59,7 @@ def __init__(
params: ParamsT,
lr: float = 3e-3,
weight_decay: float = 0.01,
- use_decoupled_weight_decay: bool = True,
+ use_decoupled_wd: bool = True,
momentum: float = 0.9,
beta_lip: float = 0.9,
precond_lr: float = 0.1,
@@ -69,18 +69,18 @@
warmup_steps: int = 10000,
max_update_rms: float = 0.0,
) -> None:
+ self.use_decoupled_wd = use_decoupled_wd
+ self.max_update_rms = max_update_rms
+ self.precond_init_scale = precond_init_scale
+ self.damping_noise_scale = damping_noise_scale
+ self.warmup_steps = warmup_steps
defaults = {
"lr": lr,
"beta_lip": beta_lip,
"weight_decay": weight_decay,
"use_decoupled_weight_decay": use_decoupled_weight_decay,
"momentum": momentum,
"precond_lr": precond_lr,
"precond_init_scale": precond_init_scale,
"max_update_rms": max_update_rms,
"min_precond_lr": min_precond_lr,
"warmup_steps": warmup_steps,
"damping_noise_scale": damping_noise_scale,
}
super().__init__(params, defaults)

@@ -114,12 +114,12 @@ def step(self, closure: Callable[[], float] | None = None) -> float | None:
if "Q" not in state or "L" not in state:
state["Q"], state["L"] = _init_psgd_kron_states(
grad,
- precond_init_scale=group["precond_init_scale"],
+ precond_init_scale=self.precond_init_scale,
)

# weight decay
if group["weight_decay"] > 0.0:
if group["use_decoupled_weight_decay"]:
if self.use_decoupled_wd:
# Apply decoupled weight decay
p.add_(p, alpha=(-group["lr"] * group["weight_decay"]))
else:
@@ -131,21 +131,20 @@ def step(self, closure: Callable[[], float] | None = None) -> float | None:
exp_avg.lerp_(grad, 1 - group["momentum"])

# Get hyperparameters for preconditioner update
- damping_noise_scale = group["damping_noise_scale"]
precond_lr = _get_precond_lr(
group["precond_lr"], state["step"], group["min_precond_lr"], group["warmup_steps"]
group["precond_lr"], state["step"], group["min_precond_lr"], self.warmup_steps
)

beta_lip = group["beta_lip"]
# Preconditioner update
state["Q"], state["L"] = _update_precond_procrustes(
state["Q"], state["L"], exp_avg, damping_noise_scale, precond_lr, beta_lip
state["Q"], state["L"], exp_avg, self.damping_noise_scale, precond_lr, beta_lip
)
uniformize_q_in_place(state["Q"])

# Get weight update by preconditioning the momentum
update = apply_preconditioner(state["Q"], exp_avg)
- _clip_update_rms_in_place(update, group["max_update_rms"])
+ _clip_update_rms_in_place(update, self.max_update_rms)

# Apply weight update
p.add_(update, alpha=-group["lr"])
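For downstream code, the visible effect of this file's changes is the renamed constructor keyword, plus the fact that `precond_init_scale`, `max_update_rms`, `warmup_steps`, and `damping_noise_scale` are now stored on the optimizer instance rather than in the per-group defaults, so they can no longer vary between parameter groups. A hedged call-site sketch follows; the import path mirrors the file location above, `model` is an arbitrary stand-in module, and it assumes `__init__` has no `**kwargs` catch-all, so the old keyword would raise a `TypeError`.

```python
import torch
from emerging_optimizers.psgd.psgd import PSGDPro  # path per the diff header above

model = torch.nn.Linear(64, 64)

opt = PSGDPro(
    model.parameters(),
    lr=3e-3,
    weight_decay=0.01,
    use_decoupled_wd=True,   # was: use_decoupled_weight_decay=True
    max_update_rms=0.0,      # now an instance-level setting shared by all param groups
    warmup_steps=10000,      # likewise no longer overridable per parameter group
)
```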