Skip to content

Commit fe42f31

Browse files
committed
missed addmm
1 parent 64c9aa9 commit fe42f31

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

emerging_optimizers/orthogonalized_optimizers/spectral_clipping_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ def spectral_clip(X: torch.Tensor, sigma_min: float = -1.0, sigma_max: float = 1
4444
for s, sign in zip([sigma_min, sigma_max], [1, -1]):
4545
A = torch.add(s * identity_matrix, OX @ X.T, alpha=-1)
4646
B = torch.add(s * OX, X, alpha=-1)
47-
result = torch.add(result, sign * newton_schulz(A, steps=8, coefficient_type="polar_express") @ B)
47+
result = torch.addmm(result, newton_schulz(A, steps=8, coefficient_type="polar_express"), B, alpha=sign)
4848
result = result * 0.5
4949

5050
if needs_transpose:

0 commit comments

Comments
 (0)