@@ -65,7 +65,7 @@ def __init__(
6565 precond_lr : float = 0.1 ,
6666 precond_init_scale : float = 1.0 ,
6767 damping_noise_scale : float = 0.1 ,
68- min_precond_lr : float = 0.3 ,
68+ min_precond_lr : float = 0.01 ,
6969 warmup_steps : int = 10000 ,
7070 max_update_rms : float = 0.0 ,
7171 ) -> None :
@@ -182,7 +182,7 @@ def _update_precond_procrustes(
182182 q_list : List [torch .Tensor ],
183183 lip_const_list : List [torch .Tensor ],
184184 exp_avg : torch .Tensor ,
185- damping_noise_scale : float ,
185+ damping_noise_scale : float = 1e-9 ,
186186 precond_lr : float = 0.1 ,
187187 beta_lip : float = 0.9 ,
188188) -> Tuple [List [torch .Tensor ], List [torch .Tensor ]]:
@@ -200,7 +200,10 @@ def _update_precond_procrustes(
200200 q_list: List of Kronecker factors.
201201 lip_const_list: List of Lipschitz constants for the Kronecker factors.
202202 """
203- pg = apply_preconditioner (q_list , torch .add (exp_avg , torch .randn_like (exp_avg ) * damping_noise_scale , alpha = 1.0 ))
203+ dampened_momentum = exp_avg + (
204+ damping_noise_scale + torch .finfo (exp_avg .dtype ).eps * exp_avg .abs ()
205+ ) * torch .randn_like (exp_avg )
206+ pg = apply_preconditioner (q_list , dampened_momentum )
204207 total_numel = pg .numel ()
205208 updated_q_list : List [torch .Tensor ] = []
206209 updated_lip_const_list : List [torch .Tensor ] = []
@@ -282,7 +285,7 @@ def _update_1d_preconditioner(
282285 return q , lip_const
283286
284287
285- def _get_precond_lr (precond_lr : float , step : int , min_precond_lr : float = 0.3 , warmup_steps : int = 10000 ) -> float :
288+ def _get_precond_lr (precond_lr : float , step : int , min_precond_lr : float = 0.01 , warmup_steps : int = 10000 ) -> float :
286289 r"""Helper function to get preconditioner learning rate for this optimization step based on a square root schedule.
287290
288291 Decaying from a higher lr down to min_precond_lr improves accuracy.
0 commit comments