@@ -30,6 +30,7 @@ def __init__(
         lr: float = 1e-3,
         betas: BETAS = (0.9, 0.999),
         weight_decay: float = 0.0,
+        n_sma_threshold: int = 5,
         degenerated_to_sgd: bool = True,
         eps: float = 1e-8,
     ):
@@ -38,18 +39,32 @@ def __init__(
         :param lr: float. learning rate.
         :param betas: BETAS. coefficients used for computing running averages of gradient and the squared hessian trace
         :param weight_decay: float. weight decay (L2 penalty)
+        :param n_sma_threshold: int. number of SMA threshold (recommended: 5)
         :param degenerated_to_sgd: bool. perform SGD update when variance of the gradient is high
         :param eps: float. term added to the denominator to improve numerical stability
         """
         self.lr = lr
         self.betas = betas
         self.weight_decay = weight_decay
+        self.n_sma_threshold = n_sma_threshold
         self.degenerated_to_sgd = degenerated_to_sgd
         self.eps = eps

         self.check_valid_parameters()

-        defaults: DEFAULTS = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
+        if isinstance(params, (list, tuple)) and len(params) > 0 and isinstance(params[0], dict):
+            for param in params:
+                if 'betas' in param and (param['betas'][0] != betas[0] or param['betas'][1] != betas[1]):
+                    param['buffer'] = [[None, None, None] for _ in range(10)]
+
+        defaults: DEFAULTS = dict(
+            lr=lr,
+            betas=betas,
+            eps=eps,
+            weight_decay=weight_decay,
+            buffer=[[None, None, None] for _ in range(10)],
+        )
+
         super().__init__(params, defaults)

     def check_valid_parameters(self):
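The constructor now seeds a 10-slot rectification `buffer` (one `[step, n_sma, step_size]` triple per slot) into the defaults, and gives any param group whose `betas` differ from the top-level ones a private copy, so a step size cached under one `beta2` is never reused by a group running another. A minimal usage sketch; the class name `DiffRGrad` and the import path are assumptions, since this hunk does not show the class declaration:

```python
import torch

from pytorch_optimizer import DiffRGrad  # hypothetical import path

backbone = torch.nn.Linear(10, 2)
head = torch.nn.Linear(2, 1)

optimizer = DiffRGrad(
    [
        {'params': backbone.parameters()},                    # shares the default buffer
        {'params': head.parameters(), 'betas': (0.9, 0.99)},  # different beta2 -> private buffer
    ],
    lr=1e-3,
    betas=(0.9, 0.999),
    n_sma_threshold=5,
)
```

The cached triple depends only on `step` and `beta2`, which is why groups that share `betas` can safely share one buffer.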
@@ -77,17 +92,22 @@ def step(self, closure: CLOSURE = None) -> LOSS:
                 if p.grad is None:
                     continue

-                grad = p.grad.data
+                grad = p.grad.data.float()
                 if grad.is_sparse:
                     raise RuntimeError('diffGrad does not support sparse gradients')

+                p_data_fp32 = p.data.float()
                 state = self.state[p]

                 if len(state) == 0:
                     state['step'] = 0
-                    state['exp_avg'] = torch.zeros_like(p.data)
-                    state['exp_avg_sq'] = torch.zeros_like(p.data)
-                    state['previous_grad'] = torch.zeros_like(p.data)
+                    state['exp_avg'] = torch.zeros_like(p_data_fp32)
+                    state['exp_avg_sq'] = torch.zeros_like(p_data_fp32)
+                    state['previous_grad'] = torch.zeros_like(p_data_fp32)
+                else:
+                    state['exp_avg'] = state['exp_avg'].type_as(p_data_fp32)
+                    state['exp_avg_sq'] = state['exp_avg_sq'].type_as(p_data_fp32)
+                    state['previous_grad'] = state['previous_grad'].type_as(p_data_fp32)

                 exp_avg, exp_avg_sq, previous_grad = (
                     state['exp_avg'],
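The new `.float()` casts follow the common fp32 master-copy pattern: moment and update math runs in float32 even when the parameters are half precision, and `type_as` re-casts state restored from a checkpoint. A minimal, self-contained sketch of just that pattern (not the optimizer itself):

```python
import torch

# fp32 master-copy pattern in isolation: do the update math in float32,
# then copy the result back into the (possibly fp16) parameter.
p = torch.nn.Parameter(torch.randn(4).half())
p.grad = torch.randn(4).half()

grad = p.grad.data.float()    # fp32 working copy of the gradient
p_data_fp32 = p.data.float()  # fp32 working copy of the parameter

p_data_fp32.add_(grad, alpha=-1e-3)  # stand-in for the real update rule
p.data.copy_(p_data_fp32)            # write back, casting to fp16
```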
@@ -98,27 +118,55 @@ def step(self, closure: CLOSURE = None) -> LOSS:

                 state['step'] += 1

-                if group['weight_decay'] != 0:
-                    grad.add_(group['weight_decay'], p.data)
-
-                # Decay the first and second moment running average coefficient
                 exp_avg.mul_(beta1).add_(1 - beta1, grad)
                 exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad)
-                denom = exp_avg_sq.sqrt().add_(group['eps'])
-
-                bias_correction1 = 1 - beta1 ** state['step']
-                bias_correction2 = 1 - beta2 ** state['step']

                 # compute diffGrad coefficient (dfc)
                 diff = abs(previous_grad - grad)
                 dfc = 1.0 / (1.0 + torch.exp(-diff))
-                state['previous_grad'] = grad.clone()
-
-                # update momentum with dfc
-                exp_avg1 = exp_avg * dfc

-                step_size = group['lr'] * math.sqrt(bias_correction2) / bias_correction1
+                state['previous_grad'] = grad.clone()

-                p.data.addcdiv_(-step_size, exp_avg1, denom)
+                buffered = group['buffer'][int(state['step'] % 10)]
+                if state['step'] == buffered[0]:
+                    n_sma, step_size = buffered[1], buffered[2]
+                else:
+                    buffered[0] = state['step']
+                    beta2_t = beta2 ** state['step']
+                    n_sma_max = 2.0 / (1.0 - beta2) - 1.0
+                    n_sma = n_sma_max - 2.0 * state['step'] * beta2_t / (1.0 - beta2_t)
+                    buffered[1] = n_sma
+
+                    if n_sma >= self.n_sma_threshold:
+                        step_size = math.sqrt(
+                            (1 - beta2_t)
+                            * (n_sma - 4)
+                            / (n_sma_max - 4)
+                            * (n_sma - 2)
+                            / n_sma
+                            * n_sma_max
+                            / (n_sma_max - 2)
+                        ) / (1.0 - beta1 ** state['step'])
+                    elif self.degenerated_to_sgd:
+                        step_size = 1.0 / (1 - beta1 ** state['step'])
+                    else:
+                        step_size = -1
+                    buffered[2] = step_size
+
+                if n_sma >= self.n_sma_threshold:
+                    if group['weight_decay'] != 0:
+                        p_data_fp32.add_(-group['weight_decay'] * group['lr'], p_data_fp32)
+
+                    denom = exp_avg_sq.sqrt().add_(group['eps'])
+
+                    # update momentum with dfc
+                    p_data_fp32.addcdiv_(-step_size * group['lr'], exp_avg * dfc.float(), denom)
+                    p.data.copy_(p_data_fp32)
+                elif step_size > 0:
+                    if group['weight_decay'] != 0:
+                        p_data_fp32.add_(-group['weight_decay'] * group['lr'], p_data_fp32)
+
+                    p_data_fp32.add_(-step_size * group['lr'], exp_avg)
+                    p.data.copy_(p_data_fp32)

         return loss
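For reference, the quantity being cached is RAdam's rectification term: `n_sma` approximates the length of the SMA of the adaptive learning rate, and until it clears `n_sma_threshold` the update degenerates to bias-corrected momentum SGD (or is skipped entirely when `step_size == -1`). The diffGrad coefficient `dfc = 1 / (1 + exp(-|g_t - g_{t-1}|))` then scales the first moment, damping coordinates whose gradient has barely changed (dfc near 0.5) and letting fast-changing ones move at nearly full rate (dfc near 1). A standalone sketch of the step-size arithmetic, mirroring the hunk above; the helper name is illustrative, not part of the patched file:

```python
import math

def rectified_step_size(
    step: int,
    beta1: float,
    beta2: float,
    n_sma_threshold: int = 5,
    degenerated_to_sgd: bool = True,
) -> float:
    # Re-derives the scalar multiplier cached in group['buffer'] above.
    beta2_t = beta2 ** step
    n_sma_max = 2.0 / (1.0 - beta2) - 1.0
    n_sma = n_sma_max - 2.0 * step * beta2_t / (1.0 - beta2_t)

    if n_sma >= n_sma_threshold:  # variance is tractable: rectified Adam step
        rect = math.sqrt(
            (1 - beta2_t) * (n_sma - 4) / (n_sma_max - 4) * (n_sma - 2) / n_sma * n_sma_max / (n_sma_max - 2)
        )
        return rect / (1.0 - beta1 ** step)
    if degenerated_to_sgd:  # early steps: plain bias-corrected momentum SGD
        return 1.0 / (1.0 - beta1 ** step)
    return -1.0  # sentinel: skip the update

# With betas=(0.9, 0.999), step 1 gives n_sma ~= 1 (SGD fallback);
# n_sma first clears the threshold of 5 at step 6.
```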