Skip to content

Commit 4f7d2c2

Browse files
committed
added missing torch addmm with mm
1 parent a170f44 commit 4f7d2c2

File tree

1 file changed

+3
-1
lines changed

1 file changed

+3
-1
lines changed

emerging_optimizers/orthogonalized_optimizers/spectral_clipping_utils.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,9 @@ def spectral_hardcap(X: torch.Tensor, beta: float = 1.0) -> torch.Tensor:
7171
OX = newton_schulz(X, steps=8, coefficient_type="polar_express")
7272
aX = torch.add(beta * OX, X, alpha=-1)
7373
result = torch.add(beta * OX, X)
74-
result = torch.add(result, aX @ newton_schulz(aX, steps=8, coefficient_type="polar_express").T @ OX, alpha=-1)
74+
result = torch.addmm(
75+
result, aX, torch.mm(newton_schulz(aX, steps=8, coefficient_type="polar_express").T, OX), alpha=-1
76+
)
7577
result = result * 0.5
7678
if needs_transpose:
7779
result = result.T

0 commit comments

Comments
 (0)