Skip to content

Commit 318cf5f

Browse files
committed
add PSGD PRO, fix existing PSGD
1 parent 50c49d4 commit 318cf5f

File tree

4 files changed

+291
-25
lines changed

4 files changed

+291
-25
lines changed

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@ Muon, MuonLaProp, OrthoLaProp, LaPropOrtho
7070
SOAP, PaLMSOAP, PrecondScheduleSOAP, PrecondSchedulePaLMSOAP, SOAPNAdam, SOAPAdEMAMix, ForeachSOLP
7171

7272
**PSGD (Kronecker):**
73-
PSGDKron, CachedPSGDKron, DelayedPSGD, CachedDelayedPSGDKron, PurePSGD, NewtonPSGDKron, NewtonHybrid2PSGDKron
73+
PSGDPRO, PSGDKron, CachedPSGDKron, DelayedPSGD, CachedDelayedPSGDKron, PurePSGD, NewtonPSGDKron, NewtonHybrid2PSGDKron
7474

7575
`Newton`-PSGD requires a closure passed to `step()`.
7676

heavyball/__init__.py

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1023,6 +1023,82 @@ class NewtonHybrid2PSGDKron(ForeachCachedNewtonPSGD):
10231023
hvp_interval = 2
10241024

10251025

1026+
class PSGDPRO(C.BaseOpt):
    """
    PSGD with Q0.5EQ1.5 (PRO/Procrustes) preconditioner update.
    Solve-free alternative to standard PSGD-Kron (EQ method).
    Reference: https://github.com/lixilinx/psgd_torch
    """

    # Class-level defaults; overridable per-instance via the matching
    # ``C.use_default`` constructor arguments below.
    cached: bool = False
    exp_avg_input: bool = True

    def __init__(
        self,
        params,
        lr=0.001,
        beta=None,
        betas=(0.9, 0.999),
        weight_decay=0.0,
        preconditioner_update_probability=C.use_default,
        max_size_triangular=2048,
        min_ndim_triangular=2,
        memory_save_mode=None,
        momentum_into_precond_update=True,
        warmup_steps: int = 0,
        merge_dims: bool = False,
        split: bool = False,
        foreach: bool = True,
        q_dtype="float32",
        stochastic_schedule: bool = False,
        storage_dtype: str = "float32",
        mars: bool = False,
        caution: bool = False,
        mars_gamma: float = 0.0025,
        cached: Optional[bool] = C.use_default,
        exp_avg_input: Optional[bool] = C.use_default,
        gradient_clipping: C.str_or_fn = C.use_default,
        update_clipping: C.str_or_fn = C.use_default,
        precond_grad_accum: bool = False,
        lower_bound_beta: float = 0.9,
        dampening: float = 2**-13,
        precond_update_power_iterations: int = 2,
        precond_init_scale=None,
        precond_init_scale_scale: float = 1,
        precond_init_scale_power: Optional[float] = None,
        precond_lr: float = 0.1,
        compile_step: bool = C.use_default,
        promote: bool = C.use_default,
        ecc: str | None = None,
        param_ecc: str | None = None,
        **kwargs,
    ):
        # Resolve use_default sentinels against the class-level defaults.
        cached = C.default(cached, self.cached)
        exp_avg_input = C.default(exp_avg_input, self.exp_avg_input)
        update_clipping = C.default(update_clipping, utils.trust_region_clip_)

        params, defaults = C._build_defaults(locals())
        # PRO keeps Q as full (square) factors and is not inverse-free.
        defaults["store_triu_as_line"] = False
        defaults["inverse_free"] = False

        # Probability schedule controlling how often the preconditioner updates.
        self.precond_schedule = C.default(
            defaults.pop("preconditioner_update_probability"), utils.precond_update_prob_schedule()
        )

        super().__init__(
            params,
            defaults,
            foreach,
            gradient_clipping,
            update_clipping,
            False,
            # Optional exp-avg (momentum) input stage, then the PRO scaling step.
            fns=(
                *(C.exp_avg,) * exp_avg_input,
                functools.partial(C.scale_by_psgd_pro, cached=cached),
            ),
        )
1100+
1101+
10261102
class ForeachPSGDLRA(C.BaseOpt):
10271103
"""
10281104
Originally from Evan Walters and Omead Pooladzandi, 2024

heavyball/chainable.py

Lines changed: 116 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -792,6 +792,27 @@ def _init_psgd_kron(state, group, update, grad, param, cached: bool = False, pro
792792
state["Q_cache"] = [torch.empty_like(q) for q in Q]
793793

794794

795+
def _init_psgd_pro_kron(state, group, update, grad, param, cached: bool = False, prob: Optional[callable] = None):
    """Initialize per-parameter state for the PSGD-PRO Kronecker preconditioner.

    Builds the Kronecker factors ``Q`` from the gradient's shape, one float64
    running-lower-bound accumulator per factor, and a scalar step counter.
    With ``cached=True`` an uninitialized cache buffer per factor is added.
    """
    Q = utils.init_Q_exprs(
        grad,
        group["precond_init_scale"],
        group["precond_init_scale_scale"],
        group["precond_init_scale_power"],
        group["max_size_triangular"],
        group["min_ndim_triangular"],
        group["memory_save_mode"],
        None,
        None,
        dtype=getattr(torch, group["q_dtype"]),
    )
    state["Q"] = Q
    # float64 keeps the lower-bound EMA stable regardless of q_dtype.
    state["running_lower_bound"] = [torch.zeros((1,), device=q.device, dtype=torch.float64) for q in Q]
    state["step"] = torch.zeros((), device=param.device, dtype=torch.float64)
    if not cached:
        return
    state["Q_cache"] = [torch.empty_like(q) for q in Q]
814+
815+
795816
def _init_psgd_lra(state, group, update, grad, param, cached: bool = False, prob: Optional[callable] = None):
796817
state["U"], state["V"], state["d"] = utils.init_lra(
797818
grad,
@@ -1094,6 +1115,56 @@ def _update_psgd_precond(
10941115
return None
10951116

10961117

1118+
def _update_psgd_pro_precond(
    cached,
    Q_cache,
    group,
    param,
    grad,
    Q,
    running_lower_bound,
    step,
    prob: Optional[callable] = None,
) -> None:
    """Run one PRO preconditioner update and, when caching, refresh ``Q_cache``.

    No-op while ``group["is_preconditioning"]`` is false. Sets
    ``group["is_cached"]`` so the grad-precondition step knows whether the
    cached product is valid.
    """
    if prob is None:
        prob = utils.precond_update_prob_schedule()

    if not group["is_preconditioning"]:
        return

    utils.psgd_pro_update_precond(
        grad,
        group["precond_lr"],
        Q,
        running_lower_bound,
        group["lower_bound_beta"],
        group["precond_update_power_iterations"],
        group["dampening"],
    )

    # Caching only pays off when precond updates are infrequent (< 50% of steps).
    float_prob = prob if isinstance(prob, float) else prob(group["step"])
    group["is_cached"] = should_use_cache = cached and float_prob < 0.5

    if not should_use_cache or not cached:
        return

    for idx, (cache_buf, factor) in enumerate(zip(Q_cache, Q)):
        if cache_buf is None:
            # Lazily allocate: diagonal factors cache elementwise q*q,
            # dense factors cache the square matrix Q^T Q.
            if factor.ndim == 1:
                cache_buf = torch.empty_like(factor)
            else:
                cache_buf = torch.empty(
                    factor.shape[0], factor.shape[0], device=factor.device, dtype=factor.dtype
                )
            Q_cache[idx] = cache_buf
        if factor.ndim == 2:
            torch.matmul(factor.T, factor, out=cache_buf)
        else:
            torch.mul(factor, factor, out=cache_buf)
1166+
1167+
10971168
def _cached_psgd_precond_grad(group, update, Q, Q_cache, grad):
10981169
kwargs = {"ea": update, "caution": group["caution"], "grad": grad}
10991170
if group.get("is_cached", False) and Q_cache[0] is not None:
@@ -1297,6 +1368,51 @@ def update_by_delayed_psgd(
12971368
raise SkipUpdate from None
12981369

12991370

1371+
@needs_full_param
@SqueezeGrad
@PrecondGradAccumGuard
@general_guard("Q", "Q_cache", "running_lower_bound", "step", init_fn=_init_psgd_pro_kron, skip_first=False)
@no_state_no_foreach
def scale_by_psgd_pro(
    group,
    update,
    grad,
    param,
    update_to_precond,
    Q,
    Q_cache,
    running_lower_bound: List[Tensor],
    step: Tensor,
    cached: bool = False,
    prob: Optional[callable] = None,
):
    """Chainable transform: PRO-update the preconditioner, then return the
    preconditioned ``update`` (via the cached path when the cache is valid)."""
    _update_psgd_pro_precond(cached, Q_cache, group, param, update_to_precond, Q, running_lower_bound, step, prob)
    return _cached_psgd_precond_grad(group, update, Q, Q_cache, grad)
1391+
1392+
1393+
@needs_full_param
@SqueezeGrad
@PrecondGradAccumGuard
@general_guard("Q", "Q_cache", "running_lower_bound", "step", init_fn=_init_psgd_pro_kron, skip_first=False)
@no_state_no_foreach
def update_by_psgd_pro(
    group,
    update,
    grad,
    param,
    update_to_precond,
    Q,
    Q_cache,
    running_lower_bound: List[Tensor],
    step: Tensor,
    cached: bool = False,
    prob: Optional[callable] = None,
):
    """Fused variant of ``scale_by_psgd_pro``: applies the preconditioned step
    directly to ``param`` and aborts the rest of the chain via SkipUpdate."""
    _update_psgd_pro_precond(cached, Q_cache, group, param, update_to_precond, Q, running_lower_bound, step, prob)
    # NOTE(review): `update` is passed both as the update and in the grad slot —
    # mirrors the non-PRO fused path; confirm against _fused_cached_psgd_precond_grad.
    _fused_cached_psgd_precond_grad(group, update, param, update, Q, Q_cache)
    raise SkipUpdate from None
1414+
1415+
13001416
def palm_beta2(state, group, update, grad, param):
13011417
beta2 = 1 - group["step"] ** -group["beta2_scale"]
13021418
group["betas"] = (utils.get_beta1(group), beta2)

heavyball/utils.py

Lines changed: 98 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -2525,9 +2525,10 @@ def lra_precond(U: Tensor, V: Tensor, d: Tensor, g: Tensor):
25252525

25262526
@decorator_knowngood
def dampen_grad(g: Tensor, damp: float = 2**-13):
    # https://github.com/lixilinx/psgd_torch/blob/89b4cead31b7ad1494c4cf4dc39f4cbf920ff14d/psgd.py
    v = torch.randn_like(g)
    # Elementwise damping: constant floor plus an eps-scaled |g| term, so large
    # entries receive proportionally larger noise.
    damping = damp + torch.finfo(g.dtype).eps * g.abs()
    return v, g + damping * v
25312532

25322533

25332534
@decorator_knowngood
@@ -2768,6 +2769,44 @@ def max_singular_value(A: Tensor, max_svd: int = 0, use_cholesky: bool = False,
27682769
return max_singular_value_power_iter(A, None, iterations=power_iter)
27692770

27702771

2772+
@decorator_knowngood
def max_eigenvalue_spd(A_outer: Tensor, power_iter: int = 4) -> Tensor:
    """Power iteration for the largest eigenvalue of a symmetric positive (semi)definite matrix.

    Exploits A = A^T: A^T A = A^2, so v -> A^T(Av) == v -> A(Av), saving a transpose.
    Uses x @ A.mT (gemm transB=true) for faster BLAS dispatch than A.mv(x).
    """
    # Scalars / vectors: the "matrix" is diagonal-like, so max entry suffices.
    if A_outer.ndim < 2:
        return A_outer.max()
    # Start from the row with the largest norm; its norm also scales A to
    # keep the iteration numerically tame.
    x_norm, max_idx = A_outer.norm(dim=1).max(dim=0)
    x_norm = promote(x_norm)

    def _inner():
        x = A_outer.index_select(0, max_idx).flatten().contiguous()
        A = stochastic_round_(A_outer / x_norm)
        x = x / x_norm

        def _mv(x):
            # Two right-multiplies by A.mT == applying A twice (A symmetric).
            return promote((x.to(A.dtype) @ A.mT) @ A.mT)

        for _ in range(power_iter):
            x = F.normalize(_mv(x), dim=0)
        # Rayleigh quotient of A^2, so sqrt recovers lambda_max(A); undo scaling.
        return (x @ _mv(x)).to(x_norm.dtype).sqrt() * x_norm

    # Guard against the all-zero matrix (x_norm == 0 would divide by zero).
    return cond(x_norm > 0, _inner, lambda: x_norm.squeeze().clone()).squeeze()
2795+
2796+
2797+
@decorator_knowngood
def procrustes_step(Q: Tensor, max_step_size: float = 1 / 8) -> None:
    """One in-place step of the Procrustes-style update on square factor Q.

    Moves Q along R = (Q^T - Q) (the skew-symmetric residual), with a step
    size from a trace-based quadratic model clamped to ``max_step_size``.
    NOTE(review): presumably this drives Q toward symmetry per the PRO method —
    confirm against the psgd_torch reference.
    """
    R = (Q.T - Q).contiguous()
    # Normalize by R's spectral norm; smallest_normal guards divide-by-zero.
    R_norm = max_singular_value(R, power_iter=2) + torch.finfo(R.dtype).smallest_normal
    R = R / R_norm
    RQ = R @ Q
    RRQ = R @ RQ
    tr_RQ = RQ.diagonal().sum()
    tr_RRQ = RRQ.diagonal().sum()
    # Quadratic model step: a = -tr(RQ)/tr(RRQ) when curvature is negative,
    # otherwise fall back to the maximum allowed step.
    a = torch.where(tr_RRQ < 0, torch.clamp(-tr_RQ / tr_RRQ, max=max_step_size), max_step_size)
    # Second-order in-place update: Q += a*RQ + 0.5*a^2*RRQ.
    Q.add_(a * (RQ + 0.5 * a * RRQ))
2808+
2809+
27712810
@decorator_knowngood
27722811
def clamped_max_singular_value(
27732812
A: Tensor, min: float, max_svd: int = 0, use_cholesky: bool = False, power_iter: int = 16
@@ -2927,22 +2966,11 @@ def _chebychef_coeff(degree: int, device, eps: float = 1e-8):
29272966
return coeff0.float(), coeffs.float()
29282967

29292968

2930-
@decorator_knowngood
2931-
def _psgd_default_preconditioner_grad(
2932-
terms: List[Tuple[Tensor, Tensor]],
2933-
Q: List[Tensor],
2934-
) -> List[Tensor]:
2935-
out = []
2936-
for q, (x, y) in zip(Q, terms):
2937-
x = promote(x)
2938-
y = promote(y)
2939-
update = x - y
2940-
if q.ndim < 2:
2941-
update = promote(q) * update
2942-
else:
2943-
update = (promote(q) @ update).triu()
2944-
out.append(update)
2945-
return out
2969+
def _update_lb(ell: Tensor, lb_state: Tensor, beta: Tensor) -> Tensor:
    """EMA-tracked lower bound: relax ``lb_state`` toward ``ell`` with factor
    (1 - beta), never dropping below the fresh estimate; state updated in place."""
    cur = promote(ell)
    prev = promote(lb_state)
    relaxed = prev + (cur - prev) * (1 - beta)
    bound = cur.maximum(relaxed)
    copy_stochastic_(lb_state, bound)
    return bound
29462974

29472975

29482976
@decorator
@@ -2965,15 +2993,61 @@ def psgd_update_precond(
29652993
precond_lr, beta2, lower_bount_beta = scalar_guard(precond_lr, beta2, lower_bount_beta, G)
29662994

29672995
A, conjB = psgd_calc_A_and_conjB(G, Q, V)
2968-
terms = [(compiled_einsum(exprG, A, A), compiled_einsum(exprG, conjB, conjB)) for exprG in exprGs]
2969-
del A, conjB, V
2970-
updates = _psgd_default_preconditioner_grad(terms, Q)
2971-
_psgd_precond_update_(
2972-
updates, oq, running_lower_bound, lower_bount_beta, precond_lr, store_triu_as_line, power_iter
2973-
)
2996+
del V
2997+
2998+
for oq_i, q, exprG, lb_state in zip(oq, Q, exprGs, running_lower_bound):
2999+
term1 = promote(compiled_einsum(exprG, A, A))
3000+
term2 = promote(compiled_einsum(exprG, conjB, conjB))
3001+
3002+
if q.ndim < 2:
3003+
ell = _update_lb((term1 + term2).max(), lb_state, lower_bount_beta)
3004+
update = promote(q) * (term1 - term2)
3005+
else:
3006+
ell = _update_lb(max_eigenvalue_spd(term1 + term2, power_iter=power_iter), lb_state, lower_bount_beta)
3007+
update = (term1 - term2).triu() @ promote(q)
3008+
if store_triu_as_line:
3009+
update = triu_to_line([update])[0][1]
3010+
3011+
real_oq = oq_i[1] if isinstance(oq_i, tuple) else oq_i
3012+
copy_stochastic_(real_oq, promote(real_oq) - update / ell * precond_lr)
29743013
return None
29753014

29763015

3016+
@decorator
def psgd_pro_update_precond(
    G: Tensor,
    precond_lr: float,
    Q: List[Tensor],
    running_lower_bound: List[Tensor],
    lower_bount_beta: float,  # [sic] spelling kept — matches sibling psgd_update_precond
    power_iter: int,
    dampening: float,
) -> None:
    """Update Kronecker product preconditioner Q with Q0.5EQ1.5 (PRO) method."""
    psgd_balance_Q(Q)
    exprGs = calcG_expr(ndim_tuple(Q), G.ndim)
    precond_lr, lower_bount_beta = scalar_guard(precond_lr, lower_bount_beta, G)

    # Noise injection matching dampen_grad's elementwise damping scheme.
    damping = dampening + torch.finfo(G.dtype).eps * G.abs()
    Pg = psgd_precond_grad(G + damping * torch.randn_like(G), Q)

    total_numel = G.numel()
    for q, exprG, lb_state in zip(Q, exprGs, running_lower_bound):
        # Gram-like term of the preconditioned gradient along this factor's axes.
        term1 = promote(compiled_einsum(exprG, Pg, Pg))
        q_ = promote(q)

        if q.ndim < 2:
            # Diagonal factor: target trace per element; max(1, ...) guards empty q.
            term2 = total_numel / max(1, q.numel())
            ell = _update_lb(term1.max() + term2, lb_state, lower_bount_beta)
            copy_stochastic_(q, q_ - q_ * (term1 - term2) / ell * precond_lr)
        else:
            # Dense factor: step normalized by a lower bound on the curvature's
            # largest eigenvalue, then re-symmetrize via a Procrustes step.
            term2 = total_numel / q.shape[0]
            ell = _update_lb(max_eigenvalue_spd(term1, power_iter=power_iter) + term2, lb_state, lower_bount_beta)
            copy_stochastic_(q, q_ - (term1 @ q_ - term2 * q_) / ell * precond_lr)
            procrustes_step(q)
    del Pg
3049+
3050+
29773051
@decorator_knowngood
def bf16_matmul(x: Tensor, y: Tensor):
    """Matmul in promoted precision, cast back to x's dtype."""
    return (promote(x) @ promote(y)).to(x.dtype)

0 commit comments

Comments
 (0)