Skip to content

Commit 4c9ac16

Browse files
committed
higher psgd precision
1 parent bff2265 commit 4c9ac16

File tree

1 file changed

+11
-19
lines changed

1 file changed

+11
-19
lines changed

heavyball/utils.py

Lines changed: 11 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -2363,7 +2363,7 @@ def init_Q_exprs(
23632363

23642364
@decorator_knowngood
def psgd_balance_Q(Q):
    """Rebalance the factors of a Kronecker-factored preconditioner in place.

    Each factor ``q`` is rescaled so that its largest-magnitude entry moves to
    the geometric mean of all factors' maxima. The per-factor scales are
    ``exp(geometric_mean - n)`` and sum to zero in log-space, so the implied
    product of the factors is unchanged — only the magnitude is redistributed,
    which guards individual factors against overflow/underflow in low
    precision.

    Args:
        Q: list of preconditioner factor tensors; mutated in place.
    """
    # Log of each factor's largest absolute entry; promote() upcasts first so
    # the log/exp round-trip happens in higher precision.
    norms = [promote(q.abs().max()).log() for q in Q]
    # sum(norms) directly — the original sum([n for n in norms]) built a
    # redundant identity list comprehension.
    geometric_mean = sum(norms) / len(Q)
    for q, n in zip(Q, norms):
        # Scale each factor toward the shared log-mean magnitude.
        q *= (geometric_mean - n).exp()
@@ -2726,7 +2726,7 @@ def max_singular_value_cholesky(A: Tensor, max_abs: Optional[Tensor] = None):
27262726
Adapted from @evanatyourservice
27272727
"""
27282728
if max_abs is None:
2729-
max_abs = A.norm(float("inf")).clamp(min=1e-8)
2729+
max_abs = A.abs().max().clamp(min=1e-8)
27302730

27312731
# cholesky uses random projection, but this uses topk -- topk is a warm start, which may converge to a biased result
27322732
k = 2 ** math.ceil(math.log2(math.log2(min(A.shape)))) # next-largest-power-of-2 of log2-of-size
@@ -2794,15 +2794,16 @@ def _mv(x):
27942794

27952795
@decorator_knowngood
def procrustes_step(Q: Tensor, max_step_size: float = 1 / 8) -> None:
    """Apply one in-place rotation step nudging square matrix ``Q`` toward symmetry.

    Builds the antisymmetric residual ``R = Q.T - Q`` (zero iff Q is symmetric)
    and updates ``Q <- (I + a*R + 0.5*a^2*R^2) @ Q``, a second-order truncation
    of the rotation ``exp(a*R) @ Q``. The result is written back into ``Q``'s
    storage dtype via stochastic rounding.

    Args:
        Q: square matrix, modified in place.
        max_step_size: upper bound on the step length ``a``.
    """
    # Work in higher precision; Q itself may be stored in a low-precision dtype.
    Q_ = promote(Q)
    # Antisymmetric part of Q — the "distance from symmetric" direction.
    R = (Q_.T - Q_).contiguous()
    # Normalize R by its approximate spectral norm so the step length is
    # scale-free; smallest_normal avoids division by zero when Q is already
    # (numerically) symmetric.
    R_norm = max_singular_value(R, power_iter=2) + torch.finfo(R.dtype).smallest_normal
    R = R / R_norm
    RQ = R @ Q_
    RRQ = R @ RQ
    tr_RQ = RQ.diagonal().sum()
    tr_RRQ = RRQ.diagonal().sum()
    # Quadratic line search on the trace objective: use the stationary point
    # -tr_RQ/tr_RRQ when the curvature term is negative (clamped to
    # max_step_size), otherwise fall back to the full max_step_size.
    a = torch.where(tr_RRQ < 0, torch.clamp(-tr_RQ / tr_RRQ, max=max_step_size), max_step_size)
    # Q + a*(RQ + 0.5*a*RRQ) == (I + a*R + 0.5*a^2*R^2) @ Q; copy_stochastic_
    # writes the high-precision result back with stochastic rounding
    # (presumably to reduce bias when Q is bf16/fp16 — confirm in its def).
    copy_stochastic_(Q, Q_ + a * (RQ + 0.5 * a * RRQ))
28062807

28072808

28082809
@decorator_knowngood
@@ -3031,7 +3032,7 @@ def psgd_pro_update_precond(
30313032

30323033
total_numel = G.numel()
30333034
for q, exprG, lb_state in zip(Q, exprGs, running_lower_bound):
3034-
term1 = promote(compiled_einsum(exprG, Pg, Pg))
3035+
term1 = compiled_einsum(exprG, Pg, Pg)
30353036
q_ = promote(q)
30363037

30373038
if q.ndim < 2:
@@ -3159,7 +3160,7 @@ def _psgd_precond_update_(
31593160

31603161
q = promote(oq)
31613162
if update.ndim < 2:
3162-
lb = update.norm(float("inf"))
3163+
lb = update.abs().max()
31633164
else:
31643165
lb = max_singular_value(update, power_iter=power_iter)
31653166
update = promote(update)
@@ -3543,25 +3544,21 @@ def precond_grad_cached_(
35433544
cached_q: List[Tensor],
35443545
caution: bool = False,
35453546
grad: Optional[Tensor] = None,
3546-
cast: bool = True,
35473547
):
35483548
if caution:
35493549
ea = _compilable_cautioning(grad, ea)
35503550
args = [promote(q) for q in cached_q]
35513551
args = args + [promote(ea)]
35523552
expr = cached_precond_grad_expr(ndim_tuple(cached_q), ea.ndim)
3553-
new = compiled_einsum(expr, *args)
3554-
if cast:
3555-
return new.to(ea.dtype)
3556-
return new
3553+
return compiled_einsum(expr, *args)
35573554

35583555

35593556
TriuOrLine = Union[List[Tensor], List[Tuple[Optional[List[int]], Tensor]]]
35603557

35613558

35623559
@decorator_knowngood
def _compilable_fused_precond_grad_cached_(ea: Tensor, param, lr, grad, decay, caution, cached_q: List[Tensor]):
    # Fused kernel: precondition the update `ea` with the cached Q factors,
    # then apply it to `param` with learning rate `lr` and weight decay.
    # Cautioning (masking by sign agreement with `grad`) is handled inside
    # precond_grad_cached_, so it is explicitly disabled in update_param_
    # to avoid applying it twice.
    precond = precond_grad_cached_(ea, cached_q, caution=caution, grad=grad)
    update_param_(param, precond, lr, decay, caution=False)
35663563

35673564

@@ -3589,18 +3586,14 @@ def psgd_precond_grad(
35893586
grad: Optional[Tensor] = None,
35903587
store_triu_as_line: bool = False,
35913588
symmetric_output: bool = False,
3592-
cast: bool = True,
35933589
):
35943590
if caution:
35953591
ea = _compilable_cautioning(grad, ea)
35963592
if store_triu_as_line:
35973593
preconds = line_to_triu(preconds, symmetric_output)
35983594
args = [promote(q) for q in preconds]
35993595
expr = precond_grad_expr(ndim_tuple(args), ea.ndim)
3600-
new = compiled_einsum(expr, *[a for a in args for _ in (0, 1)], promote(ea))
3601-
if cast:
3602-
return new.to(ea.dtype)
3603-
return new
3596+
return compiled_einsum(expr, *[a for a in args for _ in (0, 1)], promote(ea))
36043597

36053598

36063599
@decorator_knowngood
@@ -3622,7 +3615,6 @@ def _compilable_fused_psgd_precond_grad(
36223615
grad=grad,
36233616
store_triu_as_line=store_triu_as_line,
36243617
symmetric_output=symmetric_output,
3625-
cast=False,
36263618
)
36273619
update_param_(param, precond, lr, decay, caution=False, grad=grad)
36283620

0 commit comments

Comments
 (0)