@@ -2363,7 +2363,7 @@ def init_Q_exprs(
 
 @decorator_knowngood
 def psgd_balance_Q(Q):
-    norms = [promote(q.norm(float("inf"))).log() for q in Q]
+    norms = [promote(q.abs().max()).log() for q in Q]
     geometric_mean = sum([n for n in norms]) / len(Q)
     for q, n in zip(Q, norms):
        q *= (geometric_mean - n).exp()
@@ -2726,7 +2726,7 @@ def max_singular_value_cholesky(A: Tensor, max_abs: Optional[Tensor] = None):
     Adapted from @evanatyourservice
     """
     if max_abs is None:
-        max_abs = A.norm(float("inf")).clamp(min=1e-8)
+        max_abs = A.abs().max().clamp(min=1e-8)
 
     # cholesky uses random projection, but this uses topk -- topk is a warm start, which may converge to a biased result
     k = 2 ** math.ceil(math.log2(math.log2(min(A.shape))))  # next-largest-power-of-2 of log2-of-size
@@ -2794,15 +2794,16 @@ def _mv(x):
 
 @decorator_knowngood
 def procrustes_step(Q: Tensor, max_step_size: float = 1 / 8) -> None:
-    R = (Q.T - Q).contiguous()
+    Q_ = promote(Q)
+    R = (Q_.T - Q_).contiguous()
     R_norm = max_singular_value(R, power_iter=2) + torch.finfo(R.dtype).smallest_normal
     R = R / R_norm
-    RQ = R @ Q
+    RQ = R @ Q_
     RRQ = R @ RQ
     tr_RQ = RQ.diagonal().sum()
     tr_RRQ = RRQ.diagonal().sum()
     a = torch.where(tr_RRQ < 0, torch.clamp(-tr_RQ / tr_RRQ, max=max_step_size), max_step_size)
-    Q.add_(a * (RQ + 0.5 * a * RRQ))
+    copy_stochastic_(Q, Q_ + a * (RQ + 0.5 * a * RRQ))
 
 
28082809@decorator_knowngood
@@ -3031,7 +3032,7 @@ def psgd_pro_update_precond(
 
     total_numel = G.numel()
     for q, exprG, lb_state in zip(Q, exprGs, running_lower_bound):
-        term1 = promote(compiled_einsum(exprG, Pg, Pg))
+        term1 = compiled_einsum(exprG, Pg, Pg)
         q_ = promote(q)
 
         if q.ndim < 2:
@@ -3159,7 +3160,7 @@ def _psgd_precond_update_(
 
     q = promote(oq)
     if update.ndim < 2:
-        lb = update.norm(float("inf"))
+        lb = update.abs().max()
     else:
         lb = max_singular_value(update, power_iter=power_iter)
     update = promote(update)
@@ -3543,25 +3544,21 @@ def precond_grad_cached_(
     cached_q: List[Tensor],
     caution: bool = False,
     grad: Optional[Tensor] = None,
-    cast: bool = True,
 ):
     if caution:
         ea = _compilable_cautioning(grad, ea)
     args = [promote(q) for q in cached_q]
     args = args + [promote(ea)]
     expr = cached_precond_grad_expr(ndim_tuple(cached_q), ea.ndim)
-    new = compiled_einsum(expr, *args)
-    if cast:
-        return new.to(ea.dtype)
-    return new
+    return compiled_einsum(expr, *args)
 
 
 TriuOrLine = Union[List[Tensor], List[Tuple[Optional[List[int]], Tensor]]]
 
 
 @decorator_knowngood
 def _compilable_fused_precond_grad_cached_(ea: Tensor, param, lr, grad, decay, caution, cached_q: List[Tensor]):
-    precond = precond_grad_cached_(ea, cached_q, caution=caution, grad=grad, cast=False)
+    precond = precond_grad_cached_(ea, cached_q, caution=caution, grad=grad)
     update_param_(param, precond, lr, decay, caution=False)
 
 
@@ -3589,18 +3586,14 @@ def psgd_precond_grad(
     grad: Optional[Tensor] = None,
     store_triu_as_line: bool = False,
     symmetric_output: bool = False,
-    cast: bool = True,
 ):
     if caution:
         ea = _compilable_cautioning(grad, ea)
     if store_triu_as_line:
         preconds = line_to_triu(preconds, symmetric_output)
     args = [promote(q) for q in preconds]
     expr = precond_grad_expr(ndim_tuple(args), ea.ndim)
-    new = compiled_einsum(expr, *[a for a in args for _ in (0, 1)], promote(ea))
-    if cast:
-        return new.to(ea.dtype)
-    return new
+    return compiled_einsum(expr, *[a for a in args for _ in (0, 1)], promote(ea))
 
 
 @decorator_knowngood
@@ -3622,7 +3615,6 @@ def _compilable_fused_psgd_precond_grad(
         grad=grad,
         store_triu_as_line=store_triu_as_line,
         symmetric_output=symmetric_output,
-        cast=False,
     )
     update_param_(param, precond, lr, decay, caution=False, grad=grad)
 
0 commit comments