@@ -95,6 +95,8 @@ def step(self, closure: CLOSURE = None) -> LOSS:
 
         for group in self.param_groups:
             beta1, beta2 = group['betas']
+            weight_decay: float = group['weight_decay']
+
             if self.rectify:
                 n_sma_max: float = 2.0 / (1.0 - beta2) - 1.0
 
@@ -106,13 +108,6 @@ def step(self, closure: CLOSURE = None) -> LOSS:
                 if grad.is_sparse:
                     raise NoSparseGradientError(self.__name__)
 
-                if grad.dtype in (torch.float16, torch.bfloat16):
-                    grad = grad.float()
-
-                p_fp32 = p
-                if p.dtype in (torch.float16, torch.bfloat16):
-                    p_fp32 = p_fp32.float()
-
                 state = self.state[p]
                 if len(state) == 0:
                     state['step'] = 0
@@ -122,70 +117,65 @@ def step(self, closure: CLOSURE = None) -> LOSS:
                     state['max_exp_avg_var'] = torch.zeros_like(p)
 
                 if self.weight_decouple:
-                    decay: float = (
-                        group['lr'] * group['weight_decay'] if not self.fixed_decay else group['weight_decay']
-                    )
-                    p_fp32.mul_(1.0 - decay)
-                elif group['weight_decay'] != 0:
-                    grad.add_(p_fp32, alpha=group['weight_decay'])
-
-                exp_avg, exp_avg_var = state['exp_avg'], state['exp_avg_var']
+                    p.mul_(1.0 - (group['lr'] * weight_decay if not self.fixed_decay else weight_decay))
+                elif weight_decay > 0.0:
+                    grad.add_(p, alpha=weight_decay)
 
                 state['step'] += 1
+                exp_avg, exp_avg_var = state['exp_avg'], state['exp_avg_var']
 
                 bias_correction1 = 1.0 - beta1 ** state['step']
-                bias_correction2 = 1.0 - beta2 ** state['step']
+                bias_correction2_sq = math.sqrt(1.0 - beta2 ** state['step'])
 
                 exp_avg.mul_(beta1).add_(grad, alpha=1.0 - beta1)
                 grad_residual = grad - exp_avg
-                exp_avg_var.mul_(beta2).addcmul_(grad_residual, grad_residual, value=1.0 - beta2)
-                exp_avg_var.add_(group['eps'])
-                if group['amsgrad']:
-                    torch.max(state['max_exp_avg_var'], exp_avg_var, out=exp_avg_var)
+                exp_avg_var.mul_(beta2).addcmul_(grad_residual, grad_residual, value=1.0 - beta2).add_(group['eps'])
 
-                de_nom = (exp_avg_var.sqrt() / math.sqrt(bias_correction2)).add_(group['eps'])
+                if group['amsgrad']:
+                    max_exp_avg_var = state['max_exp_avg_var']
+                    torch.max(max_exp_avg_var, exp_avg_var, out=max_exp_avg_var)
+                    de_nom = max_exp_avg_var.sqrt()
+                else:
+                    de_nom = exp_avg_var.sqrt()
+                de_nom.div_(bias_correction2_sq).add_(group['eps'])
 
                 if not self.rectify:
-                    step_size = group['lr']
-                    if not group['adamd_debias_term']:
-                        step_size /= bias_correction1
-                    p_fp32.addcdiv_(exp_avg, de_nom, value=-step_size)
+                    step_size: float = group['lr'] if group['adamd_debias_term'] else group['lr'] / bias_correction1
+                    p.addcdiv_(exp_avg, de_nom, value=-step_size)
+                    continue
+
+                buffered = group['buffer'][state['step'] % 10]
+                if state['step'] == buffered[0]:
+                    n_sma, step_size = buffered[1], buffered[2]
                 else:
-                    buffered = group['buffer'][state['step'] % 10]
-                    if state['step'] == buffered[0]:
-                        n_sma, step_size = buffered[1], buffered[2]
-                    else:
-                        buffered[0] = state['step']
-                        beta2_t = beta2 ** state['step']
-                        n_sma = n_sma_max - 2 * state['step'] * beta2_t / (1 - beta2_t)
-                        buffered[1] = n_sma
-
-                        if n_sma >= self.n_sma_threshold:
-                            step_size = math.sqrt(
-                                (1 - beta2_t)
-                                * (n_sma - 4)
-                                / (n_sma_max - 4)
-                                * (n_sma - 2)
-                                / n_sma
-                                * n_sma_max
-                                / (n_sma_max - 2)
-                            )
-                            if not group['adamd_debias_term']:
-                                step_size /= bias_correction1
-                        elif self.degenerated_to_sgd:
-                            step_size = 1.0 / bias_correction1
-                        else:
-                            step_size = -1
-
-                        buffered[2] = step_size
+                    buffered[0] = state['step']
+                    beta2_t = beta2 ** state['step']
+                    n_sma = n_sma_max - 2 * state['step'] * beta2_t / (1 - beta2_t)
+                    buffered[1] = n_sma
 
                     if n_sma >= self.n_sma_threshold:
-                        de_nom = exp_avg_var.sqrt().add_(group['eps'])
-                        p_fp32.addcdiv_(exp_avg, de_nom, value=-step_size * group['lr'])
-                    elif step_size > 0:
-                        p_fp32.add_(exp_avg, alpha=-step_size * group['lr'])
+                        step_size = math.sqrt(
+                            (1 - beta2_t)
+                            * (n_sma - 4)
+                            / (n_sma_max - 4)
+                            * (n_sma - 2)
+                            / n_sma
+                            * n_sma_max
+                            / (n_sma_max - 2)
+                        )
+                        if not group['adamd_debias_term']:
+                            step_size /= bias_correction1
+                    elif self.degenerated_to_sgd:
+                        step_size = 1.0 / bias_correction1
+                    else:
+                        step_size = -1
+
+                    buffered[2] = step_size
 
-                if p.dtype in (torch.float16, torch.bfloat16):
-                    p.copy_(p_fp32)
+                if n_sma >= self.n_sma_threshold:
+                    de_nom = exp_avg_var.sqrt().add_(group['eps'])
+                    p.addcdiv_(exp_avg, de_nom, value=-step_size * group['lr'])
+                elif step_size > 0:
+                    p.add_(exp_avg, alpha=-step_size * group['lr'])
 
         return loss
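
For readers skimming the rectification branch: the `n_sma` / `step_size` bookkeeping above is the RAdam-style variance rectification, cached per step modulo 10 in `group['buffer']`. Below is a minimal standalone sketch of what that branch computes, using the same formula as the patch; the helper name and plain-float signature are illustrative only, and it omits the `bias_correction1` division and the `degenerated_to_sgd` fallback.

```python
import math


def rectified_step_size(step: int, beta2: float, n_sma_threshold: int = 5) -> float:
    """Illustrative re-implementation of the rectification term used above.

    Returns the multiplier applied to the learning rate once the variance of
    the adaptive term is considered tractable, and -1.0 otherwise (the
    optimizer then skips the update unless it degenerates to SGD).
    """
    n_sma_max = 2.0 / (1.0 - beta2) - 1.0
    beta2_t = beta2 ** step
    n_sma = n_sma_max - 2.0 * step * beta2_t / (1.0 - beta2_t)

    if n_sma < n_sma_threshold:
        return -1.0

    return math.sqrt(
        (1.0 - beta2_t)
        * (n_sma - 4.0) / (n_sma_max - 4.0)
        * (n_sma - 2.0) / n_sma
        * n_sma_max / (n_sma_max - 2.0)
    )
```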
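
The weight-decay paths keep the same two behaviours as before the refactor: with `weight_decouple` the parameter is shrunk directly by `lr * weight_decay` (or by `weight_decay` alone when `fixed_decay` is set), otherwise the penalty is folded into the gradient as classic L2 regularization. A toy illustration of the two branches, with arbitrary values and nothing assumed beyond plain tensor ops:

```python
import torch

p = torch.ones(3)
grad = torch.zeros(3)
lr, weight_decay = 1e-2, 1e-1

# Decoupled decay (weight_decouple=True): scale the parameter itself.
p_decoupled = p.clone().mul_(1.0 - lr * weight_decay)  # tensor([0.9990, 0.9990, 0.9990])

# Classic L2 decay: add the penalty to the gradient instead.
grad_l2 = grad.clone().add_(p, alpha=weight_decay)      # tensor([0.1000, 0.1000, 0.1000])
```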