+import functools
 import warnings
 from typing import Tuple, Union, List, Dict, Any, Optional

 import torch
 from torch import Tensor
 from torch.nn import Parameter


+class WeightDecayBase:
+    def __init__(self):
+        pass
+
+    def __call__(self, mod: torch.optim.Optimizer, p: torch.Tensor, idx: int):
+        return p
+
+
+class WeightDecayChain:
+    def __init__(self, *operands: WeightDecayBase):
+        self.operands = operands
+
+    def __call__(self, mod: torch.optim.Optimizer):
+        idx = 0
+        for group in mod.param_groups:
+            for p in group["params"]:
+                p.data.add_(functools.reduce(lambda x, f: f(mod, x, idx), self.operands, p),
+                            alpha=-group["lr"] * group["weight_decay"])
+                idx += 1
+
+
+class LpWeightDecay(WeightDecayBase):
+    def __init__(self, power: float):
+        self.power = power
+
+    def __call__(self, mod: torch.optim.Optimizer, p: Tensor, idx: int):
+        return p.abs().pow(self.power) * p.sign()
+
+
+class L1WeightDecay(LpWeightDecay):
+    def __init__(self):
+        super().__init__(0)
+
+
+class L2WeightDecay(LpWeightDecay):
+    def __init__(self):
+        super().__init__(1)
+
+
+def _param_iterator(mod: torch.optim.Optimizer):
+    yield from (p.detach().clone() for group in mod.param_groups for p in group["params"])
+
+
+class WeightDecayToValue(WeightDecayBase):
+    def __init__(self):
+        self.target_values: List[Tensor] = ...
+        self.global_step = 0
+
+    def _on_step_start(self, mod: torch.optim.Optimizer):
+        pass
+
+    def _on_global_start(self, mod: torch.optim.Optimizer):
+        pass
+
+    def _preprocess(self, target: Tensor):
+        return target
+
+    def __call__(self, mod: torch.optim.Optimizer, p: Tensor, idx: int):
+        if idx == 0:
+            if self.global_step == 0:
+                self._on_global_start(mod)
+            self._on_step_start(mod)
+            self.global_step += 1
+        return p - self._preprocess(self.target_values[idx])
+
+
+class WeightDecayToInit(WeightDecayToValue):
+    def _on_global_start(self, mod: torch.optim.Optimizer):
+        self.target_values = list(_param_iterator(mod))
+
+
+class WeightDecayToEMA(WeightDecayToInit):
+    def __init__(self, beta: float = 0.999):
+        super().__init__()
+        self.beta = beta
+
+    def _on_global_start(self, mod: torch.optim.Optimizer):
+        self.target_values = [torch.zeros_like(x) for x in _param_iterator(mod)]
+
+    def _on_step_start(self, mod: torch.optim.Optimizer):
+        self.global_step += 1
+        for v, p in zip(self.target_values, _param_iterator(mod)):
+            v.mul_(self.beta).add_(p, alpha=1 - self.beta)
+
+    def _preprocess(self, target: Tensor):
+        return target / (1 - self.beta ** self.global_step)
+
+
 def ema_(base: Tensor, update: Tensor, beta: float, step: Optional[int] = None):
     base.mul_(beta).add_(update, alpha=1 - beta)
     if step is None:
@@ -31,12 +120,18 @@ def decay_weight_(state: Dict[str, Any], param: torch.nn.Parameter, group: Dict[
     param.mul_(1 - group["weight_decay"] * group["lr"])


+def _default_decay(weight_decay_cls: Optional[WeightDecayChain]) -> WeightDecayChain:
+    if weight_decay_cls is None:
+        return WeightDecayChain(L2WeightDecay())
+    return weight_decay_cls
+
+
 class OptimizerOptimizer(torch.optim.Optimizer):
     def __init__(self, params, inner_optimizer: torch.optim.Optimizer, learning_rate_learning_rate: float = 1,
-                 weight_decay: float = 0, decay_to_init: bool = False):
-        self.learning_rate_learning_rate = learning_rate_learning_rate
-
+                 weight_decay: float = 0, weight_decay_cls: Optional[WeightDecayChain] = None):
         self.inner_optimizer = inner_optimizer
+        self.learning_rate_learning_rate = learning_rate_learning_rate
+        self.weight_decay_cls = _default_decay(weight_decay_cls)
         param_groups = self.inner_optimizer.param_groups
         self.inner_optimizer.param_groups = []
         for group in param_groups:
@@ -45,14 +140,16 @@ def __init__(self, params, inner_optimizer: torch.optim.Optimizer, learning_rate
             group["params"] = [param]
             self.inner_optimizer.param_groups.append(group)

-        super(OptimizerOptimizer, self).__init__(params, {"weight_decay": weight_decay, "decay_to_init": decay_to_init})
+        super(OptimizerOptimizer, self).__init__(params, {"weight_decay": weight_decay})

     @torch.no_grad()
     def step(self, closure=None):
         loss = None
         if closure is not None:
             loss = closure()

+        self.weight_decay_cls(self)
+
         for group in self.param_groups:
             for p in group['params']:
                 state = self.state[p]
@@ -80,10 +177,11 @@ def step(self, closure=None):


 class Sign(torch.optim.Optimizer):
-    def __init__(self, params, base: torch.optim.Optimizer, lr: float = 1, weight_decay: float = 0,
-                 decay_to_init: bool = False, eps: float = 1e-12, graft_to_self: bool = True):
-        super().__init__(params, {"weight_decay": weight_decay, "decay_to_init": decay_to_init, "lr": lr, "eps": eps,
-                                  "graft_to_self": graft_to_self})
+    def __init__(self, params, base: torch.optim.Optimizer, lr: float = 1, weight_decay: float = 0, eps: float = 1e-12,
+                 graft_to_self: bool = True, weight_decay_cls: Optional[WeightDecayChain] = None):
+        self.weight_decay_cls = _default_decay(weight_decay_cls)
+
+        super().__init__(params, {"weight_decay": weight_decay, "lr": lr, "eps": eps, "graft_to_self": graft_to_self})
         self.base = base

     @torch.no_grad()
@@ -94,14 +192,8 @@ def step(self, closure=None):
         with torch.enable_grad():
             loss = closure()

-        params_flat = []
-        for group in self.param_groups:
-            for p in group["params"]:
-                params_flat.append(p)
-                decay_weight_(self.state[p], p, group)
-
-        params_flat = [torch.clone(p.detach()) for p in params_flat]
-
+        self.weight_decay_cls(self)
+        params_flat = list(_param_iterator(self))
         self.base.step()

         for group in self.param_groups:
@@ -150,10 +242,12 @@ class Graft(torch.optim.Optimizer):
     """

     def __init__(self, params, magnitude: torch.optim.Optimizer, direction: torch.optim.Optimizer,
-                 weight_decay: float = 0, decay_to_init: bool = False, eps: float = 1e-12, lr: float = 1):
-        super().__init__(params, {"weight_decay": weight_decay, "decay_to_init": decay_to_init, "lr": lr, "eps": eps})
+                 weight_decay: float = 0, eps: float = 1e-12, lr: float = 1,
+                 weight_decay_cls: Optional[WeightDecayChain] = None):
+        super().__init__(params, {"weight_decay": weight_decay, "lr": lr, "eps": eps})
         self.magnitude = magnitude
         self.direction = direction
+        self.weight_decay_cls = _default_decay(weight_decay_cls)

     @torch.no_grad()
     def step(self, closure=None):
@@ -163,13 +257,8 @@ def step(self, closure=None):
         with torch.enable_grad():
             loss = closure()

-        params_flat = []
-        for group in self.param_groups:
-            for p in group["params"]:
-                params_flat.append(p)
-                decay_weight_(self.state[p], p, group)
-
-        original_params = [torch.clone(p.detach()) for p in params_flat]
+        self.weight_decay_cls(self)
+        original_params = list(_param_iterator(self))

         self.magnitude.step()
         magnitudes_flat = []
@@ -194,21 +283,16 @@ class TrueGrad(torch.optim.Optimizer):
     base_statistics: List[str] = []
     shared_statistics: List[str] = []

-    def __init__(self, params, lr: float = 1e-3,
-                 betas: List[float] = (),
-                 eps: float = 1e-12,
-                 weight_decay: float = 1e-2,
-                 graft: bool = True,
-                 decay_to_init: bool = False,
-                 default_to_baseline: bool = False,
-                 enforce_baseline: bool = False):
+    def __init__(self, params, lr: float = 1e-3, betas: List[float] = (), eps: float = 1e-12,
+                 weight_decay: float = 1e-2, graft: bool = True, default_to_baseline: bool = False,
+                 enforce_baseline: bool = False, weight_decay_cls: Optional[WeightDecayChain] = None):
         defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, graft=graft,
-                        decay_to_init=decay_to_init, default_to_baseline=default_to_baseline,
-                        enforce_baseline=enforce_baseline)
+                        default_to_baseline=default_to_baseline, enforce_baseline=enforce_baseline)
         super(TrueGrad, self).__init__(params, defaults)
+        self.weight_decay_cls = _default_decay(weight_decay_cls)

-    def _inner(self, step: int, p: Parameter, group: Dict[str, Any], **kwargs: Tensor
-               ) -> Tuple[Optional[Tensor], Optional[Tensor], float]:
+    def _inner(self, step: int, p: Parameter, group: Dict[str, Any], **kwargs: Tensor) -> Tuple[
+        Optional[Tensor], Optional[Tensor], float]:
         raise NotImplementedError

     @torch.no_grad()
@@ -245,12 +329,7 @@ def step(self, closure=None):
                 step_t = state['step']
                 step_t += 1

-                # Perform stepweight decay
-                decay = group['lr'] * group['weight_decay']
-                if group["decay_to_init"]:
-                    p.add_(state["init"] - p, alpha=decay)
-                else:
-                    p.mul_(1 - decay)
+                self.weight_decay_cls(self)

                 step = step_t.item()

@@ -275,28 +354,18 @@ class TGAdamW(TrueGrad):

     def __init__(self, params, lr: float = 1e-3,
                  betas: Union[Tuple[float, float], Tuple[float, float, float]] = (0.9, 0.999, 0.999),
-                 eps: float = 1e-12,
-                 weight_decay: float = 1e-2,
-                 graft: bool = True,
-                 decay_to_init: bool = False,
-                 default_to_adam: bool = None,
-                 default_to_baseline: bool = None,
-                 enforce_baseline: bool = False):
-        if default_to_baseline is None:
-            default_to_baseline = default_to_adam
-        elif default_to_adam is not None:
-            raise ValueError("Can't set both default_to_baseline and default_to_adam, as both map to the same argument")
-        if default_to_adam is not None:
-            warnings.warn("default_to_adam is deprecated and will be replaced by default_to_baseline in April 2023")
+                 eps: float = 1e-12, weight_decay: float = 1e-2, graft: bool = True,
+                 default_to_baseline: bool = None, enforce_baseline: bool = False,
+                 weight_decay_cls: Optional[WeightDecayChain] = None):
         if default_to_baseline is None:
             default_to_baseline = False
         super().__init__(params, lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, graft=graft,
-                         decay_to_init=decay_to_init, default_to_baseline=default_to_baseline,
-                         enforce_baseline=enforce_baseline)
+                         default_to_baseline=default_to_baseline, enforce_baseline=enforce_baseline,
+                         weight_decay_cls=weight_decay_cls)

     def _inner(self, step: int, p: Parameter, group: Dict[str, Any], exp_avg: Tensor,
-               exp_avg_sq: Optional[Tensor] = None, exp_avg_true_sq: Optional[Tensor] = None
-               ) -> Tuple[Optional[Tensor], Optional[Tensor], float]:
+               exp_avg_sq: Optional[Tensor] = None, exp_avg_true_sq: Optional[Tensor] = None) -> Tuple[
+        Optional[Tensor], Optional[Tensor], float]:
         if len(group["betas"]) == 2:
             (beta1, beta2), (_, beta3) = group["betas"], group["betas"]
         else:
@@ -317,21 +386,17 @@ class TGLaProp(TrueGrad):
     base_statistics: List[str] = ["exp_avg", "exp_avg_sq"]

     def __init__(self, params, lr: float = 1e-3,
-                 betas: Union[Tuple[float, float], Tuple[float, float, float, float]] = (0.9, 0.99),
-                 eps: float = 1e-12,
-                 weight_decay: float = 1e-2,
-                 graft: bool = True,
-                 decay_to_init: bool = False,
-                 default_to_baseline: bool = False,
-                 enforce_baseline: bool = False):
+                 betas: Union[Tuple[float, float], Tuple[float, float, float, float]] = (0.9, 0.99), eps: float = 1e-12,
+                 weight_decay: float = 1e-2, graft: bool = True, decay_to_init: bool = False,
+                 default_to_baseline: bool = False, enforce_baseline: bool = False,
+                 weight_decay_cls: Optional[WeightDecayChain] = None):
         super().__init__(params, lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, graft=graft,
-                         decay_to_init=decay_to_init, default_to_baseline=default_to_baseline,
-                         enforce_baseline=enforce_baseline)
+                         default_to_baseline=default_to_baseline, enforce_baseline=enforce_baseline,
+                         weight_decay_cls=weight_decay_cls)

-    def _inner(self, step: int, p: Parameter, group: Dict[str, Any],
-               exp_avg: Optional[Tensor] = None, exp_avg_sq: Optional[Tensor] = None,
-               exp_avg_true: Optional[Tensor] = None, exp_avg_true_sq: Optional[Tensor] = None
-               ) -> Tuple[Optional[Tensor], Optional[Tensor], float]:
+    def _inner(self, step: int, p: Parameter, group: Dict[str, Any], exp_avg: Optional[Tensor] = None,
+               exp_avg_sq: Optional[Tensor] = None, exp_avg_true: Optional[Tensor] = None,
+               exp_avg_true_sq: Optional[Tensor] = None) -> Tuple[Optional[Tensor], Optional[Tensor], float]:
         if len(group["betas"]) == 2:
             (beta1, beta2), (beta3, beta4) = group["betas"], group["betas"]
         else:
@@ -362,21 +427,16 @@ class TGRMSProp(TrueGrad):
     true_statistics: List[str] = ["exp_avg_true_sq"]
     base_statistics: List[str] = ["exp_avg_sq"]

-    def __init__(self, params, lr: float = 1e-3,
-                 betas: Union[float, Tuple[float], Tuple[float, float]] = (0.9,),
-                 eps: float = 1e-12,
-                 weight_decay: float = 1e-2,
-                 graft: bool = True,
-                 decay_to_init: bool = False,
-                 default_to_baseline: bool = False,
-                 enforce_baseline: bool = False):
+    def __init__(self, params, lr: float = 1e-3, betas: Union[float, Tuple[float], Tuple[float, float]] = (0.9,),
+                 eps: float = 1e-12, weight_decay: float = 1e-2, graft: bool = True,
+                 default_to_baseline: bool = False, enforce_baseline: bool = False,
+                 weight_decay_cls: Optional[WeightDecayChain] = None):
         super().__init__(params, lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, graft=graft,
-                         decay_to_init=decay_to_init, default_to_baseline=default_to_baseline,
-                         enforce_baseline=enforce_baseline)
+                         default_to_baseline=default_to_baseline, enforce_baseline=enforce_baseline,
+                         weight_decay_cls=weight_decay_cls)

-    def _inner(self, step: int, p: Parameter, group: Dict[str, Any],
-               exp_avg_sq: Optional[Tensor] = None, exp_avg_true_sq: Optional[Tensor] = None
-               ) -> Tuple[Optional[Tensor], Optional[Tensor], float]:
+    def _inner(self, step: int, p: Parameter, group: Dict[str, Any], exp_avg_sq: Optional[Tensor] = None,
+               exp_avg_true_sq: Optional[Tensor] = None) -> Tuple[Optional[Tensor], Optional[Tensor], float]:
         if isinstance(group["betas"], float):
             beta1 = beta2 = group["betas"]
         elif len(group["betas"]) == 1:
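
A minimal usage sketch of the WeightDecayChain hook introduced above, not taken from the diff itself: it assumes the classes from this commit are in scope (the module path is not shown here) and drives the chain by hand with a plain SGD optimizer rather than one of the TrueGrad optimizers.

    # Illustrative sketch under the assumptions above: compose a decay target
    # (an EMA of each parameter) with an L1-style magnitude, applied once per step.
    import torch

    model = torch.nn.Linear(8, 2)
    opt = torch.optim.SGD(model.parameters(), lr=1e-2, weight_decay=1e-2)
    decay = WeightDecayChain(WeightDecayToEMA(beta=0.99), L1WeightDecay())

    for _ in range(100):
        loss = model(torch.randn(16, 8)).square().mean()
        loss.backward()
        decay(opt)      # p -= lr * weight_decay * sign(p - debiased EMA of p)
        opt.step()
        opt.zero_grad()

When weight_decay_cls is left as None, _default_decay falls back to WeightDecayChain(L2WeightDecay()), which reproduces ordinary decoupled weight decay toward zero.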