@@ -42,7 +42,7 @@ class PSGDPro(torch.optim.Optimizer):
         params: Iterable of parameters to optimize or dicts defining parameter groups
         lr: The learning rate to use
         weight_decay: Weight decay coefficient
-        use_decoupled_weight_decay: Whether to use decoupled weight decay, see Decoupled Weight Decay Regularization:
+        use_decoupled_wd: Whether to use decoupled weight decay, see Decoupled Weight Decay Regularization:
             https://arxiv.org/abs/1711.05101.
         momentum: Momentum coefficient for exponential moving average of gradient.
         beta_lip: EMA beta for the Lipschitz constants.
@@ -59,7 +59,7 @@ def __init__(
         params: ParamsT,
         lr: float = 3e-3,
         weight_decay: float = 0.01,
-        use_decoupled_weight_decay: bool = True,
+        use_decoupled_wd: bool = True,
         momentum: float = 0.9,
         beta_lip: float = 0.9,
         precond_lr: float = 0.1,
@@ -69,18 +69,18 @@ def __init__(
         warmup_steps: int = 10000,
         max_update_rms: float = 0.0,
     ) -> None:
+        self.use_decoupled_wd = use_decoupled_wd
+        self.max_update_rms = max_update_rms
+        self.precond_init_scale = precond_init_scale
+        self.damping_noise_scale = damping_noise_scale
+        self.warmup_steps = warmup_steps
         defaults = {
             "lr": lr,
             "beta_lip": beta_lip,
             "weight_decay": weight_decay,
-            "use_decoupled_weight_decay": use_decoupled_weight_decay,
             "momentum": momentum,
             "precond_lr": precond_lr,
-            "precond_init_scale": precond_init_scale,
-            "max_update_rms": max_update_rms,
             "min_precond_lr": min_precond_lr,
-            "warmup_steps": warmup_steps,
-            "damping_noise_scale": damping_noise_scale,
         }
         super().__init__(params, defaults)
 
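One consequence of this hunk worth noting: the values moved from `defaults` onto `self` are now global to the optimizer instance, so they can no longer vary per parameter group, and they are not serialized by `torch.optim.Optimizer.state_dict()`, which captures only `state` and `param_groups`. A sketch of the distinction (same hypothetical import as above):

```python
import torch

from psgd_pro import PSGDPro  # hypothetical import path

model = torch.nn.Linear(16, 4)
# Per-group defaults (lr, weight_decay, momentum, ...) may still differ:
opt = PSGDPro([
    {"params": [model.weight]},
    {"params": [model.bias], "lr": 1e-3, "weight_decay": 0.0},
])
# Attributes stored on `self` (use_decoupled_wd, max_update_rms,
# precond_init_scale, damping_noise_scale, warmup_steps) apply uniformly to
# every group and are absent from opt.state_dict().
```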
@@ -114,12 +114,12 @@ def step(self, closure: Callable[[], float] | None = None) -> float | None:
                 if "Q" not in state or "L" not in state:
                     state["Q"], state["L"] = _init_psgd_kron_states(
                         grad,
-                        precond_init_scale=group["precond_init_scale"],
+                        precond_init_scale=self.precond_init_scale,
                     )
 
                 # weight decay
                 if group["weight_decay"] > 0.0:
-                    if group["use_decoupled_weight_decay"]:
+                    if self.use_decoupled_wd:
                         # Apply decoupled weight decay
                         p.add_(p, alpha=(-group["lr"] * group["weight_decay"]))
                     else:
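The hunk cuts off at the `else:` branch. For readers comparing the two modes, here is a sketch of the decoupled-vs-coupled split: the decoupled branch is taken from the hunk above, while the coupled branch is an assumption about the elided code, in the spirit of arXiv:1711.05101:

```python
if group["weight_decay"] > 0.0:
    if self.use_decoupled_wd:
        # Decoupled (AdamW-style): shrink the weights directly, independent
        # of the gradient path through momentum and the preconditioner.
        p.add_(p, alpha=-group["lr"] * group["weight_decay"])
    else:
        # Coupled (classic L2 regularization): fold the penalty gradient
        # into `grad` before it enters the momentum buffer. (Assumed; this
        # branch is not shown in the diff.)
        grad = grad.add(p, alpha=group["weight_decay"])
```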
@@ -131,21 +131,20 @@ def step(self, closure: Callable[[], float] | None = None) -> float | None:
                 exp_avg.lerp_(grad, 1 - group["momentum"])
 
                 # Get hyperparameters for preconditioner update
-                damping_noise_scale = group["damping_noise_scale"]
                 precond_lr = _get_precond_lr(
-                    group["precond_lr"], state["step"], group["min_precond_lr"], group["warmup_steps"]
+                    group["precond_lr"], state["step"], group["min_precond_lr"], self.warmup_steps
                 )
 
                 beta_lip = group["beta_lip"]
                 # Preconditioner update
                 state["Q"], state["L"] = _update_precond_procrustes(
-                    state["Q"], state["L"], exp_avg, damping_noise_scale, precond_lr, beta_lip
+                    state["Q"], state["L"], exp_avg, self.damping_noise_scale, precond_lr, beta_lip
                 )
                 uniformize_q_in_place(state["Q"])
 
                 # Get weight update by preconditioning the momentum
                 update = apply_preconditioner(state["Q"], exp_avg)
-                _clip_update_rms_in_place(update, group["max_update_rms"])
+                _clip_update_rms_in_place(update, self.max_update_rms)
 
                 # Apply weight update
                 p.add_(update, alpha=-group["lr"])
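Two of the helpers called above are not shown in this diff. The sketches below are hypothetical implementations consistent with the call sites and the defaults (`min_precond_lr`, `warmup_steps=10000`, and `max_update_rms=0.0` reading as "disabled"), not the repo's actual code:

```python
import torch

def _get_precond_lr(precond_lr: float, step: int, min_precond_lr: float, warmup_steps: int) -> float:
    # Assumed schedule: anneal the preconditioner learning rate from
    # `precond_lr` toward `min_precond_lr` over `warmup_steps` optimizer
    # steps, then hold the floor.
    if warmup_steps <= 0:
        return min_precond_lr
    frac = min(step / warmup_steps, 1.0)
    return precond_lr + frac * (min_precond_lr - precond_lr)

def _clip_update_rms_in_place(update: torch.Tensor, max_update_rms: float) -> None:
    # Assumed semantics: rescale `update` in place so its root-mean-square
    # does not exceed `max_update_rms`; a non-positive threshold (the
    # default 0.0 above) disables clipping.
    if max_update_rms <= 0.0:
        return
    rms = update.square().mean().sqrt()
    if rms > max_update_rms:
        update.mul_(max_update_rms / (rms + 1e-12))
```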