@@ -41,14 +41,11 @@ def spectral_clip(X: torch.Tensor, sigma_min: float = -1.0, sigma_max: float = 1
4141 OX = newton_schulz (X , steps = 8 , coefficient_type = "polar_express" )
4242 result = (sigma_min + sigma_max ) * OX
4343 identity_matrix = torch .eye (X .shape [0 ], device = X .device , dtype = X .dtype )
44- # Pre-allocate tensors for memory efficiency
45- A = torch .empty_like (identity_matrix )
46- B = torch .empty_like (X )
4744 for s , sign in zip ([sigma_min , sigma_max ], [1 , - 1 ]):
48- torch .add (s * identity_matrix , OX @ X .T , alpha = - 1 , out = A )
49- torch .add (s * OX , X , alpha = - 1 , out = B )
50- torch .add (result , sign * newton_schulz (A , steps = 8 , coefficient_type = "polar_express" ) @ B , out = result )
51- result = ( 1 / 2 ) * result
45+ A = torch .add (s * identity_matrix , OX @ X .T , alpha = - 1 )
46+ B = torch .add (s * OX , X , alpha = - 1 )
47+ result = torch .add (result , sign * newton_schulz (A , steps = 8 , coefficient_type = "polar_express" ) @ B )
48+ result = result * 0.5
5249
5350 if needs_transpose :
5451 result = result .T
@@ -73,10 +70,9 @@ def spectral_hardcap(X: torch.Tensor, beta: float = 1.0) -> torch.Tensor:
7370 X = X .T
7471 OX = newton_schulz (X , steps = 8 , coefficient_type = "polar_express" )
7572 aX = torch .add (beta * OX , X , alpha = - 1 )
76- result = torch .empty_like (X )
77- torch .add (beta * OX , X , out = result )
78- torch .add (result , aX @ newton_schulz (aX , steps = 8 , coefficient_type = "polar_express" ).T @ OX , alpha = - 1 , out = result )
79- result = (1 / 2 ) * result
73+ result = torch .add (beta * OX , X )
74+ result = torch .add (result , aX @ newton_schulz (aX , steps = 8 , coefficient_type = "polar_express" ).T @ OX , alpha = - 1 )
75+ result = result * 0.5
8076 if needs_transpose :
8177 result = result .T
8278 return result
0 commit comments