@@ -115,16 +115,6 @@ def div_ema(base: Tensor, eps: float, base_sq: Tensor, update_sq: Tensor, beta_s
115115 return base / stable_sqrt (ema_ (base_sq , update_sq , beta_sq , step ), eps )
116116
117117
118- def decay_weight_ (state : Dict [str , Any ], param : torch .nn .Parameter , group : Dict [str , Any ]):
119- if group ["decay_to_init" ]:
120- if "param_at_init" not in state :
121- state ["param_at_init" ] = torch .clone (param .detach ())
122- else :
123- param .add_ (state ["param_at_init" ] - param , alpha = group ["weight_decay" ] * group ["lr" ])
124- else :
125- param .mul_ (1 - group ["weight_decay" ] * group ["lr" ])
126-
127-
128118def _default_decay (weight_decay_cls : Optional [WeightDecayChain ]) -> WeightDecayChain :
129119 if weight_decay_cls is None :
130120 return WeightDecayChain (L2WeightDecay ())
@@ -153,16 +143,15 @@ def step(self, closure=None):
153143 if closure is not None :
154144 loss = closure ()
155145
156- self .weight_decay_cls (self )
157-
158146 for group in self .param_groups :
159147 for p in group ['params' ]:
160148 state = self .state [p ]
161149 if "lr" in state :
162150 group ["lr" ] = state ["lr" ]
163- decay_weight_ (state , p , group )
164151 state ["param" ] = torch .clone (p .detach ())
165152
153+ self .weight_decay_cls (self )
154+
166155 self .inner_optimizer .step ()
167156
168157 for group in self .inner_optimizer .param_groups :
@@ -330,8 +319,6 @@ def step(self, closure=None):
330319 if do_base or group ["graft" ]:
331320 for s in self .base_statistics :
332321 state [s ] = torch .zeros_like (p , memory_format = torch .preserve_format )
333- if group ["decay_to_init" ]:
334- state ["init" ] = torch .clone (p .detach ())
335322
336323 step_t = state ['step' ]
337324 step_t += 1
@@ -393,7 +380,7 @@ class TGLaProp(TrueGrad):
393380
394381 def __init__ (self , params , lr : float = 1e-3 ,
395382 betas : Union [Tuple [float , float ], Tuple [float , float , float , float ]] = (0.9 , 0.99 ), eps : float = 1e-12 ,
396- weight_decay : float = 1e-2 , graft : bool = True , decay_to_init : bool = False ,
383+ weight_decay : float = 1e-2 , graft : bool = True ,
397384 default_to_baseline : bool = False , enforce_baseline : bool = False ,
398385 weight_decay_cls : Optional [WeightDecayChain ] = None ):
399386 super ().__init__ (params , lr = lr , betas = betas , eps = eps , weight_decay = weight_decay , graft = graft ,
0 commit comments