-from typing import Tuple, Union
+import enum
+import warnings
+from typing import Tuple, Union, List, Dict, Any, Optional
 
 import torch
+from torch import Tensor
+from torch.nn import Parameter
 
 
-class TGAdamW(torch.optim.Optimizer):
+class BaseOptimizer(str, enum.Enum):
+    adam: str = "adam"
+    laprop: str = "laprop"
+
+
+def ema_(base: Tensor, update: Tensor, beta: float, step: int = 0):
+    base.mul_(beta).add_(update, alpha=1 - beta)
+    if not step:
+        return base
+    return base / (1 - beta ** step)
+
+
+def stable_sqrt(base: Tensor, eps: float):
+    return base.sqrt().clamp(min=eps)
+
+
+def div_ema(base: Tensor, eps: float, base_sq: Tensor, update_sq: Tensor, beta_sq: float, step: int = 0):
+    return base / stable_sqrt(ema_(base_sq, update_sq, beta_sq, step), eps)
+
+
+class TrueGrad(torch.optim.Optimizer):
+    true_statistics: List[str] = []
+    base_statistics: List[str] = []
+    shared_statistics: List[str] = []
+
     def __init__(self, params, lr: float = 1e-3,
-                 betas: Union[Tuple[float, float], Tuple[float, float, float]] = (0.9, 0.999, 0.999),
+                 betas: List[float] = (),
                  eps: float = 1e-12,
                  weight_decay: float = 1e-2,
                  graft: bool = True,
                  decay_to_init: bool = False,
-                 default_to_adam: bool = False):
+                 default_to_baseline: bool = False):
         defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, graft=graft,
-                        decay_to_init=decay_to_init, default_to_adam=default_to_adam)
-        super(TGAdamW, self).__init__(params, defaults)
+                        decay_to_init=decay_to_init, default_to_baseline=default_to_baseline)
+        super(TrueGrad, self).__init__(params, defaults)
+
+    def _inner(self, step: int, p: Parameter, do_baseline: bool, group: Dict[str, Any], **kwargs: Tensor
+               ) -> Tuple[Optional[Tensor], Optional[Tensor], float]:
+        raise NotImplementedError
 
     @torch.no_grad()
     def step(self, closure=None):
@@ -23,37 +55,30 @@ def step(self, closure=None):
             with torch.enable_grad():
                 loss = closure()
         for group in self.param_groups:
-            if len(group["betas"]) == 2:
-                beta1, beta2 = group["betas"]
-                beta3 = beta2
-            else:
-                beta1, beta2, beta3 = group['betas']
-
             for p in group['params']:
                 if p.grad is None:
                     continue
-                do_adam = not hasattr(p, "sum_grad_squared") or p.sum_grad_squared is None
-                if not group["default_to_adam"] and do_adam:
+                do_baseline = not hasattr(p, "sum_grad_squared") or p.sum_grad_squared is None
+                if not group["default_to_baseline"] and do_baseline:
                     raise ValueError(f"Parameter of shape {list(p.size())} doesn't have `sum_grad_squared` attribute. "
                                      f"Make sure to use backpack.")
 
                 state = self.state[p]
 
                 if len(state) == 0:
-                    state['step'] = torch.tensor(0.)
-                    state['exp_avg'] = torch.zeros_like(p, memory_format=torch.preserve_format)
-                    if not do_adam:
-                        state['exp_avg_true_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format)
-                    if do_adam or group["graft"]:
-                        state['exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format)
+                    state['step'] = torch.tensor(0.)
+                    for s in self.shared_statistics:
+                        state[s] = torch.zeros_like(p, memory_format=torch.preserve_format)
+                    if not do_baseline:
+                        for s in self.true_statistics:
+                            state[s] = torch.zeros_like(p, memory_format=torch.preserve_format)
+                    if do_baseline or group["graft"]:
+                        for s in self.base_statistics:
+                            state[s] = torch.zeros_like(p, memory_format=torch.preserve_format)
                     if group["decay_to_init"]:
                         state["init"] = torch.clone(p.detach())
 
-                exp_avg = state['exp_avg']
-                exp_avg_true_sq = state['exp_avg_true_sq']
                 step_t = state['step']
-
-                # update step
                 step_t += 1
 
                 # Perform stepweight decay
@@ -63,26 +88,94 @@ def step(self, closure=None):
                 else:
                     p.mul_(1 - decay)
 
-                exp_avg.mul_(beta1).add_(p.grad, alpha=1 - beta1)
-
                 step = step_t.item()
-                alpha = -group['lr'] / (1 - beta1 ** step)
-
-                if not do_adam:
-                    exp_avg_true_sq.mul_(beta3).add_(p.sum_grad_squared, alpha=1 - beta3)
-                    p.sum_grad_squared = None
-                    denom = (exp_avg_true_sq / (1 - beta3 ** step)).sqrt().add_(group['eps'])
-                    update = exp_avg / denom
 
-                if group["graft"] or do_adam:
-                    exp_avg_sq = state['exp_avg_sq']
-                    exp_avg_sq.mul_(beta2).add_(p.grad.square(), alpha=1 - beta2)
-                    adam_update = exp_avg / (exp_avg_sq / (1 - beta2 ** step)).sqrt().add_(group['eps'])
+                base_update, update, alpha = self._inner(step, p, do_baseline, group,
+                                                         **{k: state[k] for k in self.shared_statistics},
+                                                         **{k: state[k] for k in self.base_statistics if k in state},
+                                                         **{k: state[k] for k in self.true_statistics if k in state})
 
-                if group["graft"] and not do_adam:
-                    alpha = alpha * adam_update.norm() / update.norm().add_(group['eps'])
-                elif do_adam:
-                    update = adam_update
+                if group["graft"] and not do_baseline:
+                    alpha = alpha * base_update.norm() / update.norm().add_(group['eps'])
+                elif do_baseline:
+                    update = base_update
 
                 p.add_(update, alpha=alpha)
         return loss
+
+
+class TGAdamW(TrueGrad):
+    true_statistics: List[str] = ["exp_avg_true_sq"]
+    base_statistics: List[str] = ["exp_avg_sq"]
+    shared_statistics: List[str] = ["exp_avg"]
+
+    def __init__(self, params, lr: float = 1e-3,
+                 betas: Union[Tuple[float, float], Tuple[float, float, float]] = (0.9, 0.999, 0.999),
+                 eps: float = 1e-12,
+                 weight_decay: float = 1e-2,
+                 graft: bool = True,
+                 decay_to_init: bool = False,
+                 default_to_adam: Optional[bool] = None,
+                 default_to_baseline: Optional[bool] = None):
+        if default_to_baseline is None:
+            default_to_baseline = default_to_adam
+        elif default_to_adam is not None:
+            raise ValueError("Can't set both default_to_baseline and default_to_adam, as both map to the same argument")
+        if default_to_adam is not None:
+            warnings.warn("default_to_adam is deprecated and will be replaced by default_to_baseline in April 2023")
+        if default_to_baseline is None:
+            default_to_baseline = False
+        super().__init__(params, lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, graft=graft,
+                         decay_to_init=decay_to_init, default_to_baseline=default_to_baseline)
+
+    def _inner(self, step: int, p: Parameter, do_baseline: bool, group: Dict[str, Any], exp_avg: Tensor,
+               exp_avg_sq: Optional[Tensor] = None, exp_avg_true_sq: Optional[Tensor] = None
+               ) -> Tuple[Optional[Tensor], Optional[Tensor], float]:
+        if len(group["betas"]) == 2:
+            (beta1, beta2), (_, beta3) = group["betas"], group["betas"]
+        else:
+            beta1, beta2, beta3 = group['betas']
+
+        update, base_update, eps = None, None, group["eps"]
+        ema_(exp_avg, p.grad, beta1)
+        if exp_avg_true_sq is not None:
+            update = div_ema(exp_avg, eps, exp_avg_true_sq, p.sum_grad_squared, beta3, step)
+        if exp_avg_sq is not None:
+            base_update = div_ema(exp_avg, eps, exp_avg_sq, p.grad.square(), beta2, step)
+
+        return base_update, update, -group['lr'] / (1 - beta1 ** step)
+
+
+class TGLaProp(TrueGrad):
+    true_statistics: List[str] = ["exp_avg_true", "exp_avg_true_sq"]
+    base_statistics: List[str] = ["exp_avg", "exp_avg_sq"]
+
+    def __init__(self, params, lr: float = 1e-3,
+                 betas: Union[Tuple[float, float], Tuple[float, float, float, float]] = (0.9, 0.99),
+                 eps: float = 1e-12,
+                 weight_decay: float = 1e-2,
+                 graft: bool = True,
+                 decay_to_init: bool = False,
+                 default_to_baseline: bool = False):
+        super().__init__(params, lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, graft=graft,
+                         decay_to_init=decay_to_init, default_to_baseline=default_to_baseline)
+
+    def _inner(self, step: int, p: Parameter, do_baseline: bool, group: Dict[str, Any],
+               exp_avg: Optional[Tensor] = None, exp_avg_sq: Optional[Tensor] = None,
+               exp_avg_true: Optional[Tensor] = None, exp_avg_true_sq: Optional[Tensor] = None
+               ) -> Tuple[Optional[Tensor], Optional[Tensor], float]:
+        if len(group["betas"]) == 2:
+            (beta1, beta2), (beta3, beta4) = group["betas"], group["betas"]
+        else:
+            beta1, beta2, beta3, beta4 = group['betas']
+
+        update, base_update, alpha, eps = None, None, 1, group["eps"]
+        if exp_avg_true_sq is not None:
+            update = ema_(exp_avg_true, div_ema(p.grad, eps, exp_avg_true_sq, p.sum_grad_squared, beta4, step), beta3)
+            alpha = -group['lr'] / (1 - beta3 ** step)
+
+        if exp_avg_sq is not None:
+            base_update = ema_(exp_avg, div_ema(p.grad, eps, exp_avg_sq, p.grad.square(), beta2, step), beta1)
+            alpha = -group['lr'] / (1 - beta1 ** step)  # if grafting, beta3 issues are "grafted" away
+
+        return base_update, update, alpha
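
Usage sketch (not part of the diff): a minimal, hypothetical example of driving one of these optimizers. It assumes backpack's `extend`/`SumGradSquared` extension is what populates `p.sum_grad_squared` (as the error message above suggests); the toy model, data, and loss below are made up for illustration.

import torch
from backpack import backpack, extend
from backpack.extensions import SumGradSquared

# hypothetical toy setup; TGAdamW / TGLaProp are the classes defined above
model = extend(torch.nn.Linear(4, 1))
loss_fn = extend(torch.nn.MSELoss())
opt = TGAdamW(model.parameters(), lr=1e-3)  # or TGLaProp(...)

x, y = torch.randn(8, 4), torch.randn(8, 1)
with backpack(SumGradSquared()):  # fills p.sum_grad_squared alongside p.grad
    loss_fn(model(x), y).backward()
opt.step()
opt.zero_grad()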