Skip to content

Commit 4afbc39

Browse files
committed
address PR comments
Signed-off-by: mikail <mkhona@nvidia.com>
1 parent 373a545 commit 4afbc39

File tree

1 file changed

+7 -11 lines changed

1 file changed

+7 -11 lines changed

emerging_optimizers/orthogonalized_optimizers/spectral_clipping_utils.py

Lines changed: 7 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -41,14 +41,11 @@ def spectral_clip(X: torch.Tensor, sigma_min: float = -1.0, sigma_max: float = 1
 41  41       OX = newton_schulz(X, steps=8, coefficient_type="polar_express")
 42  42       result = (sigma_min + sigma_max) * OX
 43  43       identity_matrix = torch.eye(X.shape[0], device=X.device, dtype=X.dtype)
 44     -     # Pre-allocate tensors for memory efficiency
 45     -     A = torch.empty_like(identity_matrix)
 46     -     B = torch.empty_like(X)
 47  44       for s, sign in zip([sigma_min, sigma_max], [1, -1]):
 48     -         torch.add(s * identity_matrix, OX @ X.T, alpha=-1, out=A)
 49     -         torch.add(s * OX, X, alpha=-1, out=B)
 50     -         torch.add(result, sign * newton_schulz(A, steps=8, coefficient_type="polar_express") @ B, out=result)
 51     -     result = (1 / 2) * result
     45 +         A = torch.add(s * identity_matrix, OX @ X.T, alpha=-1)
     46 +         B = torch.add(s * OX, X, alpha=-1)
     47 +         result = torch.add(result, sign * newton_schulz(A, steps=8, coefficient_type="polar_express") @ B)
     48 +     result = result * 0.5
 52  49
 53  50       if needs_transpose:
 54  51           result = result.T
@@ -73,10 +70,9 @@ def spectral_hardcap(X: torch.Tensor, beta: float = 1.0) -> torch.Tensor:
 73  70           X = X.T
 74  71       OX = newton_schulz(X, steps=8, coefficient_type="polar_express")
 75  72       aX = torch.add(beta * OX, X, alpha=-1)
 76     -     result = torch.empty_like(X)
 77     -     torch.add(beta * OX, X, out=result)
 78     -     torch.add(result, aX @ newton_schulz(aX, steps=8, coefficient_type="polar_express").T @ OX, alpha=-1, out=result)
 79     -     result = (1 / 2) * result
     73 +     result = torch.add(beta * OX, X)
     74 +     result = torch.add(result, aX @ newton_schulz(aX, steps=8, coefficient_type="polar_express").T @ OX, alpha=-1)
     75 +     result = result * 0.5
 80  76       if needs_transpose:
 81  77           result = result.T
 82  78       return result

0 commit comments

Comments
 (0)