
Commit 1bb3f72

Authored by Hao Wu
Clean up API (#67)
* clean up soap API
* clean psgd API
* clean up OrthogonalizedOptimizer API
* use ParamT for soap. some import order change.

Signed-off-by: Hao Wu <[email protected]>
1 parent 19d8201 commit 1bb3f72
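
The common thread across these changes is that hyperparameters which are not meant to vary per parameter group (the nesterov and weight-decay flags, preconditioner init/warmup settings, update-RMS clipping) move out of each optimizer's defaults dict and become attributes on the optimizer instance. A minimal sketch of that pattern, using a hypothetical ToyOptimizer rather than the actual classes in this repo:

import torch


class ToyOptimizer(torch.optim.Optimizer):
    """Hypothetical optimizer illustrating the refactor pattern in this commit."""

    def __init__(self, params, lr=1e-3, weight_decay=0.0, use_decoupled_wd=True):
        # Whole-optimizer flag: stored on the instance, not in the per-group defaults.
        self.use_decoupled_wd = use_decoupled_wd
        # Per-group hyperparameters: kept in defaults so each group can override them.
        defaults = dict(lr=lr, weight_decay=weight_decay)
        super().__init__(params, defaults)

    @torch.no_grad()
    def step(self, closure=None):
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()
        for group in self.param_groups:
            for p in group["params"]:
                if p.grad is None:
                    continue
                if group["weight_decay"] > 0.0:
                    if self.use_decoupled_wd:
                        # Decoupled: decay the weights directly, scaled by lr.
                        p.add_(p, alpha=-group["lr"] * group["weight_decay"])
                    else:
                        # Coupled: fold the decay into the gradient.
                        p.grad.add_(p, alpha=group["weight_decay"])
                p.add_(p.grad, alpha=-group["lr"])
        return loss

The trade-off is that flags stored on self apply uniformly to the whole optimizer and can no longer be overridden group by group.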

File tree

4 files changed: +90, -150 lines


emerging_optimizers/orthogonalized_optimizers/orthogonalized_optimizer.py

Lines changed: 7 additions & 6 deletions
@@ -112,13 +112,14 @@ def __init__(
         scaled_orthogonalize_fn = torch.nn.Identity()
 
         self.fp32_matmul_prec = fp32_matmul_prec
+        self.use_nesterov = use_nesterov
+        self.use_decoupled_wd = use_decoupled_wd
+        self.use_independent_wd = use_independent_wd
+
         default_args_dict = dict(
             lr=lr,
             momentum_beta=momentum_beta,
-            use_nesterov=use_nesterov,
             weight_decay=weight_decay,
-            use_decoupled_wd=use_decoupled_wd,
-            use_independent_wd=use_independent_wd,
             **kwargs,
         )
 
@@ -156,9 +157,9 @@ def step(self, closure: Callable[[], float] | None = None) -> float | None:
 
                 # Apply weight decay
                 if group["weight_decay"] > 0.0:
-                    if group["use_decoupled_wd"]:
+                    if self.use_decoupled_wd:
                         # Apply weight decay directly to params without changing gradients
-                        if group["use_independent_wd"]:
+                        if self.use_independent_wd:
                             # do not tie weight decay and learning rate
                             weight_decay_scale = group["weight_decay"]
                         else:
@@ -172,7 +173,7 @@ def step(self, closure: Callable[[], float] | None = None) -> float | None:
                 exp_avg.lerp_(grad, 1 - group["momentum_beta"])
 
                 # include nesterov momentum
-                if group["use_nesterov"]:
+                if self.use_nesterov:
                     grad = grad.lerp(exp_avg, group["momentum_beta"])
                 else:
                     grad = exp_avg
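
For reference, the momentum logic touched by the last hunk can be exercised in isolation. A minimal sketch using plain tensors; exp_avg and momentum_beta follow the names in the diff, everything else here is illustrative:

import torch

torch.manual_seed(0)
momentum_beta = 0.9
grad = torch.randn(4)
exp_avg = torch.zeros(4)  # momentum buffer

# EMA update of the momentum buffer: exp_avg <- beta * exp_avg + (1 - beta) * grad
exp_avg.lerp_(grad, 1 - momentum_beta)

use_nesterov = True
if use_nesterov:
    # Nesterov-style lookahead: blend the fresh gradient with the buffer,
    # i.e. (1 - beta) * grad + beta * exp_avg.
    grad = grad.lerp(exp_avg, momentum_beta)
else:
    grad = exp_avg

print(grad)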

emerging_optimizers/psgd/psgd.py

Lines changed: 12 additions & 13 deletions
@@ -42,7 +42,7 @@ class PSGDPro(torch.optim.Optimizer):
         params: Iterable of parameters to optimize or dicts defining parameter groups
         lr: The learning rate to use
         weight_decay: Weight decay coefficient
-        use_decoupled_weight_decay: Whether to use decoupled weight decay, see Decoupled Weight Decay Regularization:
+        use_decoupled_wd: Whether to use decoupled weight decay, see Decoupled Weight Decay Regularization:
             https://arxiv.org/abs/1711.05101.
         momentum: Momentum coefficient for exponential moving average of gradient.
         beta_lip: EMA beta for the Lipschitz constants.
@@ -59,7 +59,7 @@ def __init__(
         params: ParamsT,
         lr: float = 3e-3,
         weight_decay: float = 0.01,
-        use_decoupled_weight_decay: bool = True,
+        use_decoupled_wd: bool = True,
         momentum: float = 0.9,
         beta_lip: float = 0.9,
         precond_lr: float = 0.1,
@@ -69,18 +69,18 @@ def __init__(
         warmup_steps: int = 10000,
         max_update_rms: float = 0.0,
     ) -> None:
+        self.use_decoupled_wd = use_decoupled_wd
+        self.max_update_rms = max_update_rms
+        self.precond_init_scale = precond_init_scale
+        self.damping_noise_scale = damping_noise_scale
+        self.warmup_steps = warmup_steps
         defaults = {
             "lr": lr,
             "beta_lip": beta_lip,
             "weight_decay": weight_decay,
-            "use_decoupled_weight_decay": use_decoupled_weight_decay,
             "momentum": momentum,
             "precond_lr": precond_lr,
-            "precond_init_scale": precond_init_scale,
-            "max_update_rms": max_update_rms,
             "min_precond_lr": min_precond_lr,
-            "warmup_steps": warmup_steps,
-            "damping_noise_scale": damping_noise_scale,
         }
         super().__init__(params, defaults)
 
@@ -114,12 +114,12 @@ def step(self, closure: Callable[[], float] | None = None) -> float | None:
                 if "Q" not in state or "L" not in state:
                     state["Q"], state["L"] = _init_psgd_kron_states(
                         grad,
-                        precond_init_scale=group["precond_init_scale"],
+                        precond_init_scale=self.precond_init_scale,
                     )
 
                 # weight decay
                 if group["weight_decay"] > 0.0:
-                    if group["use_decoupled_weight_decay"]:
+                    if self.use_decoupled_wd:
                         # Apply decoupled weight decay
                         p.add_(p, alpha=(-group["lr"] * group["weight_decay"]))
                     else:
@@ -131,21 +131,20 @@ def step(self, closure: Callable[[], float] | None = None) -> float | None:
                 exp_avg.lerp_(grad, 1 - group["momentum"])
 
                 # Get hyperparameters for preconditioner update
-                damping_noise_scale = group["damping_noise_scale"]
                 precond_lr = _get_precond_lr(
-                    group["precond_lr"], state["step"], group["min_precond_lr"], group["warmup_steps"]
+                    group["precond_lr"], state["step"], group["min_precond_lr"], self.warmup_steps
                 )
 
                 beta_lip = group["beta_lip"]
                 # Preconditioner update
                 state["Q"], state["L"] = _update_precond_procrustes(
-                    state["Q"], state["L"], exp_avg, damping_noise_scale, precond_lr, beta_lip
+                    state["Q"], state["L"], exp_avg, self.damping_noise_scale, precond_lr, beta_lip
                 )
                 uniformize_q_in_place(state["Q"])
 
                 # Get weight update by preconditioning the momentum
                 update = apply_preconditioner(state["Q"], exp_avg)
-                _clip_update_rms_in_place(update, group["max_update_rms"])
+                _clip_update_rms_in_place(update, self.max_update_rms)
 
                 # Apply weight update
                 p.add_(update, alpha=-group["lr"])
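
The user-facing effect of the psgd.py changes is the renamed constructor argument. A hedged usage sketch; the import path is assumed from the file path above, and the class may also be re-exported at the package level:

import torch

from emerging_optimizers.psgd.psgd import PSGDPro  # assumed import path

model = torch.nn.Linear(16, 16)
opt = PSGDPro(
    model.parameters(),
    lr=3e-3,
    weight_decay=0.01,
    use_decoupled_wd=True,  # renamed from use_decoupled_weight_decay in this commit
)

x = torch.randn(8, 16)
loss = model(x).pow(2).mean()
loss.backward()
opt.step()
opt.zero_grad()

Since precond_init_scale, damping_noise_scale, warmup_steps, and max_update_rms now live on the instance, they are set once at construction and no longer appear as per-group entries in opt.param_groups.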
