@@ -2525,9 +2525,10 @@ def lra_precond(U: Tensor, V: Tensor, d: Tensor, g: Tensor):
 
 @decorator_knowngood
 def dampen_grad(g: Tensor, damp: float = 2 ** -13):
-    # https://github.com/lixilinx/psgd_torch/blob/1943e66596111e78157ca1b72b31c1dfdf0653ef/preconditioned_stochastic_gradient_descent.py#L50
+    # https://github.com/lixilinx/psgd_torch/blob/89b4cead31b7ad1494c4cf4dc39f4cbf920ff14d/psgd.py
     v = torch.randn_like(g)
-    return v, g + damp * g.abs().mean() * v
+    damping = damp + torch.finfo(g.dtype).eps * g.abs()
+    return v, g + damping * v
 
 
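The new damping swaps the old global scale, `damp * g.abs().mean()`, for an absolute floor plus a per-element relative term, so near-zero gradient entries still receive nonzero noise. A standalone sketch of the two formulas (plain torch; the tensor values are illustrative):

import torch

g = torch.tensor([1e-8, 1e-2, 1.0])
v = torch.randn_like(g)

old_noise = (2 ** -13) * g.abs().mean() * v                      # one global scale for all elements
new_noise = (2 ** -13 + torch.finfo(g.dtype).eps * g.abs()) * v  # absolute floor + per-element term
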
 @decorator_knowngood
@@ -2768,6 +2769,44 @@ def max_singular_value(A: Tensor, max_svd: int = 0, use_cholesky: bool = False,
     return max_singular_value_power_iter(A, None, iterations=power_iter)
 
 
+@decorator_knowngood
+def max_eigenvalue_spd(A_outer: Tensor, power_iter: int = 4) -> Tensor:
+    """Power iteration for the largest eigenvalue of a symmetric positive (semi)definite matrix.
+    Exploits A = A^T: A^T A = A^2, so v -> A^T(Av) reduces to v -> A(Av), saving a transpose.
+    Uses x @ A.mT (gemm transB=true) for faster BLAS dispatch than A.mv(x)."""
+    if A_outer.ndim < 2:
+        return A_outer.max()
+    x_norm, max_idx = A_outer.norm(dim=1).max(dim=0)
+    x_norm = promote(x_norm)
+
+    def _inner():
+        x = A_outer.index_select(0, max_idx).flatten().contiguous()
+        A = stochastic_round_(A_outer / x_norm)
+        x = x / x_norm
+
+        def _mv(x):
+            return promote((x.to(A.dtype) @ A.mT) @ A.mT)
+
+        for _ in range(power_iter):
+            x = F.normalize(_mv(x), dim=0)
+        return (x @ _mv(x)).to(x_norm.dtype).sqrt() * x_norm
+
+    return cond(x_norm > 0, _inner, lambda: x_norm.squeeze().clone()).squeeze()
+
+
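For intuition, a standalone check of the same SPD power iteration against an exact eigendecomposition (plain torch, none of the helpers above; the matrix size and iteration count are arbitrary):

import torch

torch.manual_seed(0)
B = torch.randn(64, 64, dtype=torch.float64)
A = B @ B.T                                    # symmetric PSD by construction

x = A[A.norm(dim=1).argmax()]                  # start from the largest-norm row
for _ in range(8):
    x = torch.nn.functional.normalize(A @ (A @ x), dim=0)  # v -> A(Av), valid since A = A^T
est = (x @ (A @ (A @ x))).sqrt()               # sqrt of the Rayleigh quotient of A^2
print(est.item(), torch.linalg.eigvalsh(A)[-1].item())     # the two should roughly agree
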
+@decorator_knowngood
+def procrustes_step(Q: Tensor, max_step_size: float = 1 / 8) -> None:
+    R = (Q.T - Q).contiguous()
+    R_norm = max_singular_value(R, power_iter=2) + torch.finfo(R.dtype).smallest_normal
+    R = R / R_norm
+    RQ = R @ Q
+    RRQ = R @ RQ
+    tr_RQ = RQ.diagonal().sum()
+    tr_RRQ = RRQ.diagonal().sum()
+    a = torch.where(tr_RRQ < 0, torch.clamp(-tr_RQ / tr_RRQ, max=max_step_size), max_step_size)
+    Q.add_(a * (RQ + 0.5 * a * RRQ))
+
+
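The step can be read as one ascent step on tr(Q) over rotations exp(aR) with skew-symmetric generator R = Q^T - Q, truncated to second order. Since the maximizer of tr(UQ) over orthogonal U is reached when UQ is the symmetric polar factor of Q, repeated steps symmetrize Q while leaving P = Q^T Q unchanged up to the O(a^3) truncation error. A standalone sketch under that reading (torch.linalg stands in for max_singular_value; sizes and step counts are illustrative):

import torch

def procrustes_step_ref(Q, max_step_size=1 / 8):
    R = Q.T - Q                                        # skew-symmetric generator
    R = R / (torch.linalg.matrix_norm(R, 2) + torch.finfo(R.dtype).tiny)
    RQ, RRQ = R @ Q, R @ (R @ Q)
    tr_RQ, tr_RRQ = RQ.trace(), RRQ.trace()
    a = min(max_step_size, float(-tr_RQ / tr_RRQ)) if tr_RRQ < 0 else max_step_size
    Q += a * (RQ + 0.5 * a * RRQ)                      # Q <- (I + aR + a^2 R^2 / 2) Q ~ exp(aR) Q

Q = torch.randn(8, 8, dtype=torch.float64)
P0 = Q.T @ Q
for _ in range(100):
    procrustes_step_ref(Q)
print(torch.linalg.matrix_norm(Q.T - Q))       # asymmetry shrinks toward 0
print(torch.linalg.matrix_norm(Q.T @ Q - P0))  # P = Q^T Q drifts only by the truncation error
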
 @decorator_knowngood
 def clamped_max_singular_value(
     A: Tensor, min: float, max_svd: int = 0, use_cholesky: bool = False, power_iter: int = 16
@@ -2927,22 +2966,11 @@ def _chebychef_coeff(degree: int, device, eps: float = 1e-8):
     return coeff0.float(), coeffs.float()
 
 
-@decorator_knowngood
-def _psgd_default_preconditioner_grad(
-    terms: List[Tuple[Tensor, Tensor]],
-    Q: List[Tensor],
-) -> List[Tensor]:
-    out = []
-    for q, (x, y) in zip(Q, terms):
-        x = promote(x)
-        y = promote(y)
-        update = x - y
-        if q.ndim < 2:
-            update = promote(q) * update
-        else:
-            update = (promote(q) @ update).triu()
-        out.append(update)
-    return out
+def _update_lb(ell: Tensor, lb_state: Tensor, beta: Tensor) -> Tensor:
+    ell = promote(ell)
+    ell = ell.maximum(promote(lb_state) + (ell - promote(lb_state)) * (1 - beta))
+    copy_stochastic_(lb_state, ell)
+    return ell
 
 
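_update_lb keeps a running upper bound that rises instantly and decays slowly: lb + (ell - lb) * (1 - beta) equals beta * lb + (1 - beta) * ell, so the maximum takes the new value whenever it exceeds the stored state and otherwise lets the state relax toward it at rate 1 - beta. A numeric sketch with plain floats (values are illustrative):

beta, state = 0.9, 0.0
for ell in [1.0, 4.0, 0.5, 0.5, 0.5]:
    state = max(ell, beta * state + (1 - beta) * ell)
    print(state)  # 1.0, 4.0, 3.65, 3.335, 3.0515: jumps up instantly, bleeds down slowly
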
 @decorator
@@ -2965,15 +2993,61 @@ def psgd_update_precond(
     precond_lr, beta2, lower_bount_beta = scalar_guard(precond_lr, beta2, lower_bount_beta, G)
 
     A, conjB = psgd_calc_A_and_conjB(G, Q, V)
-    terms = [(compiled_einsum(exprG, A, A), compiled_einsum(exprG, conjB, conjB)) for exprG in exprGs]
-    del A, conjB, V
-    updates = _psgd_default_preconditioner_grad(terms, Q)
-    _psgd_precond_update_(
-        updates, oq, running_lower_bound, lower_bount_beta, precond_lr, store_triu_as_line, power_iter
-    )
+    del V
+
+    for oq_i, q, exprG, lb_state in zip(oq, Q, exprGs, running_lower_bound):
+        term1 = promote(compiled_einsum(exprG, A, A))
+        term2 = promote(compiled_einsum(exprG, conjB, conjB))
+
+        if q.ndim < 2:
+            ell = _update_lb((term1 + term2).max(), lb_state, lower_bount_beta)
+            update = promote(q) * (term1 - term2)
+        else:
+            ell = _update_lb(max_eigenvalue_spd(term1 + term2, power_iter=power_iter), lb_state, lower_bount_beta)
+            update = (term1 - term2).triu() @ promote(q)
+        if store_triu_as_line:
+            update = triu_to_line([update])[0][1]
+
+        real_oq = oq_i[1] if isinstance(oq_i, tuple) else oq_i
+        copy_stochastic_(real_oq, promote(real_oq) - update / ell * precond_lr)
     return None
 
 
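For a single dense 2D factor, the einsum terms above reduce to plain Gram matrices and the loop body becomes a triangular PSGD update. A self-contained sketch under that assumption (torch only; ell_state, the probe v, and the exact eigensolve standing in for max_eigenvalue_spd are all illustrative):

import torch

torch.manual_seed(0)
n = 16
Q = torch.eye(n, dtype=torch.float64)        # upper-triangular preconditioner factor
ell_state = torch.zeros((), dtype=torch.float64)
lr, beta = 0.1, 0.95

for _ in range(100):
    g = torch.randn(n, dtype=torch.float64)  # stand-in gradient
    v = torch.randn(n, dtype=torch.float64)  # probe vector
    A = Q @ g                                # analogue of psgd_calc_A_and_conjB's A
    conjB = torch.linalg.solve_triangular(Q.T, v.unsqueeze(1), upper=False).squeeze(1)  # Q^{-T} v
    term1 = torch.outer(A, A)                # analogue of compiled_einsum(exprG, A, A)
    term2 = torch.outer(conjB, conjB)
    ell_new = torch.linalg.eigvalsh(term1 + term2)[-1]  # exact max eigenvalue, not power iteration
    ell_state = torch.maximum(ell_new, beta * ell_state + (1 - beta) * ell_new)
    Q = Q - ((term1 - term2).triu() @ Q) / ell_state * lr
# At the fixed point, (Q^T Q) E[g g^T] (Q^T Q) = I, i.e. P = Q^T Q whitens g.
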
+@decorator
+def psgd_pro_update_precond(
+    G: Tensor,
+    precond_lr: float,
+    Q: List[Tensor],
+    running_lower_bound: List[Tensor],
+    lower_bount_beta: float,
+    power_iter: int,
+    dampening: float,
+) -> None:
+    """Update the Kronecker-product preconditioner Q with the Q^0.5 E Q^1.5 (PRO) method."""
+    psgd_balance_Q(Q)
+    exprGs = calcG_expr(ndim_tuple(Q), G.ndim)
+    precond_lr, lower_bount_beta = scalar_guard(precond_lr, lower_bount_beta, G)
+
+    damping = dampening + torch.finfo(G.dtype).eps * G.abs()
+    Pg = psgd_precond_grad(G + damping * torch.randn_like(G), Q)
+
+    total_numel = G.numel()
+    for q, exprG, lb_state in zip(Q, exprGs, running_lower_bound):
+        term1 = promote(compiled_einsum(exprG, Pg, Pg))
+        q_ = promote(q)
+
+        if q.ndim < 2:
+            term2 = total_numel / max(1, q.numel())
+            ell = _update_lb(term1.max() + term2, lb_state, lower_bount_beta)
+            copy_stochastic_(q, q_ - q_ * (term1 - term2) / ell * precond_lr)
+        else:
+            term2 = total_numel / q.shape[0]
+            ell = _update_lb(max_eigenvalue_spd(term1, power_iter=power_iter) + term2, lb_state, lower_bount_beta)
+            copy_stochastic_(q, q_ - (term1 @ q_ - term2 * q_) / ell * precond_lr)
+            procrustes_step(q)
+    del Pg
+
+
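At the PRO fixed point the drift term vanishes when E[term1] = term2 * I, i.e. the preconditioner whitens the dampened gradient. A diagonal toy model showing this (one factor, so the einsum analogue is elementwise and total_numel / q.numel() == 1; all constants are illustrative):

import torch

torch.manual_seed(0)
sigma = torch.tensor([0.1, 1.0, 10.0])
q = torch.ones(3)                      # diagonal factor; P = q**2
ell = torch.zeros(3)
lr, beta = 0.1, 0.95

for _ in range(5000):
    g = sigma * torch.randn(3)
    Pg = q * q * g
    term1 = Pg * Pg                    # analogue of compiled_einsum(exprG, Pg, Pg)
    term2 = 1.0                        # total_numel / q.numel() for a single diagonal factor
    ell_new = term1 + term2
    ell = torch.maximum(ell_new, beta * ell + (1 - beta) * ell_new)
    q = q - q * (term1 - term2) / ell * lr

print(q * q * sigma)  # roughly [1, 1, 1]: the fixed point E[(P g)^2] = 1 gives P = 1 / sigma
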
 @decorator_knowngood
 def bf16_matmul(x: Tensor, y: Tensor):
     return (promote(x) @ promote(y)).to(x.dtype)