File tree Expand file tree Collapse file tree 1 file changed +6
-9
lines changed Expand file tree Collapse file tree 1 file changed +6
-9
lines changed Original file line number Diff line number Diff line change @@ -161,11 +161,9 @@ def step(self, closure: CLOSURE = None) -> LOSS:
161161 denom = (exp_avg_var .add_ (group ['eps' ]).sqrt () / math .sqrt (bias_correction2 )).add_ (group ['eps' ])
162162
163163 if not self .rectify :
164- if group ['adamd_debias_term' ]:
165- step_size = group ['lr' ]
166- else :
167- step_size = group ['lr' ] / bias_correction1
168-
164+ step_size = group ['lr' ]
165+ if not group ['adamd_debias_term' ]:
166+ step_size /= bias_correction1
169167 p .data .addcdiv_ (exp_avg , denom , value = - step_size )
170168 else :
171169 buffered = group ['buffer' ][int (state ['step' ] % 10 )]
@@ -189,10 +187,9 @@ def step(self, closure: CLOSURE = None) -> LOSS:
189187 / (n_sma_max - 2 )
190188 )
191189
192- if group ['adamd_debias_term' ]:
193- step_size = rt
194- else :
195- step_size = rt / bias_correction1
190+ step_size = rt
191+ if not group ['adamd_debias_term' ]:
192+ step_size /= bias_correction1
196193 elif self .degenerated_to_sgd :
197194 step_size = 1.0 / bias_correction1
198195 else :
You can’t perform that action at this time.
0 commit comments