1- from collections import defaultdict
21from typing import Callable , Dict , List , Tuple
32
43import torch
54from torch import nn
65
76from pytorch_optimizer .base .optimizer import BaseOptimizer
8- from pytorch_optimizer .base .types import CLOSURE , DEFAULTS , LOSS , OPTIMIZER , STATE
7+ from pytorch_optimizer .base .types import CLOSURE , DEFAULTS , LOSS , OPTIMIZER
98
109
1110def polyval (x : torch .Tensor , coef : torch .Tensor ) -> torch .Tensor :
@@ -119,8 +118,9 @@ def __init__(
119118 self .s_prev = s_prev
120119 self .eps = eps
121120
121+ self .f_term = self .s_prev / self .erf_imag (1.0 / torch .sqrt (torch .tensor (2.0 )))
122+
122123 self .optimizer = optimizer
123- self .state : STATE = defaultdict (dict )
124124 self .defaults : DEFAULTS = optimizer .defaults
125125
126126 def __str__ (self ) -> str :
@@ -130,6 +130,10 @@ def __str__(self) -> str:
130130 def param_groups (self ):
131131 return self .optimizer .param_groups
132132
133+ @property
134+ def state (self ):
135+ return self .optimizer .state
136+
133137 @torch .no_grad ()
134138 def reset (self ):
135139 device = self .param_groups [0 ]['params' ][0 ].device
@@ -172,7 +176,7 @@ def backup_params_and_grads(self) -> Tuple[Dict, Dict]:
172176
173177 @torch .no_grad ()
174178 def trac_step (self , updates : Dict , grads : Dict ) -> None :
175- self .state ['step' ] += 1
179+ self .state ['trac' ][ ' step' ] += 1
176180
177181 deltas = {}
178182
@@ -181,13 +185,13 @@ def trac_step(self, updates: Dict, grads: Dict) -> None:
181185 h = torch .zeros ((1 ,), device = device )
182186 for group in self .param_groups :
183187 for p in group ['params' ]:
184- if p . grad is None :
188+ if grads [ p ] is None :
185189 continue
186190
187- theta_ref = self .state [p ]
191+ theta_ref = self .state ['trac' ][ p ]
188192 update = updates [p ]
189193
190- deltas [p ] = (update - theta_ref ) / ( torch .sum (self .state ['s' ]) + self .eps )
194+ deltas [p ] = (update - theta_ref ) / torch .sum (self .state ['trac' ][ ' s' ]). add_ ( self .eps )
191195 update .neg_ ().add_ (p )
192196
193197 grad , delta = grads [p ], deltas [p ]
@@ -197,36 +201,42 @@ def trac_step(self, updates: Dict, grads: Dict) -> None:
197201
198202 delta .add_ (update )
199203
200- s = self .state ['s' ]
201- betas = self .state ['betas' ]
202- variance = self .state ['variance' ]
203- sigma = self .state ['sigma' ]
204+ s = self .state ['trac' ][ ' s' ]
205+ betas = self .state ['trac' ][ ' betas' ]
206+ variance = self .state ['trac' ][ ' variance' ]
207+ sigma = self .state ['trac' ][ ' sigma' ]
204208
205209 variance .mul_ (betas .pow (2 )).add_ (h .pow (2 ))
206210 sigma .mul_ (betas ).sub_ (h )
207211
208- f_term = self .s_prev / self .erf_imag (1.0 / torch .sqrt (torch .tensor (2.0 )))
209- s_term = self .erf_imag (sigma / (torch .sqrt (torch .tensor (2.0 )) * variance .sqrt () + self .eps ))
210- s .copy_ (f_term * s_term )
212+ s_term = self .erf_imag (sigma / (2.0 * variance ).sqrt_ ().add_ (self .eps ))
213+ s_term .mul_ (self .f_term )
214+ s .copy_ (s_term )
215+
216+ scale = max (torch .sum (s ), 0.0 )
211217
212218 for group in self .param_groups :
213219 for p in group ['params' ]:
214220 if grads [p ] is None :
215221 continue
216222
217- p .copy_ (self .state [p ] + deltas [p ] * max (torch .sum (s ), 0.0 ))
223+ delta = deltas [p ]
224+ delta .mul_ (scale ).add_ (self .state ['trac' ][p ])
225+
226+ p .copy_ (delta )
218227
219228 @torch .no_grad ()
220229 def step (self , closure : CLOSURE = None ) -> LOSS :
230+ # TODO: the backup should come first (before the optimizer step) to get the delta of the params and grads, but that does not work yet.
221231 with torch .enable_grad ():
222232 loss = self .optimizer .step (closure )
223233
224234 updates , grads = self .backup_params_and_grads ()
225235
226- if len ( self .state ) == 0 :
227- device = updates [ next ( iter ( updates . keys ())) ].device
236+ if 'trac' not in self .state :
237+ device = self . param_groups [ 0 ][ 'params' ][ 0 ].device
228238
229- self .state = {
239+ self .state [ 'trac' ] = {
230240 'betas' : torch .tensor (self .betas , device = device ),
231241 's' : torch .zeros (len (self .betas ), device = device ),
232242 'variance' : torch .zeros (len (self .betas ), device = device ),
@@ -236,7 +246,7 @@ def step(self, closure: CLOSURE = None) -> LOSS:
236246
237247 for group in self .param_groups :
238248 for p in group ['params' ]:
239- self .state [p ] = updates [p ].clone ()
249+ self .state ['trac' ][ p ] = updates [p ].clone ()
240250
241251 self .trac_step (updates , grads )
242252
0 commit comments