pytorch_optimizer/optimizer: 1 file changed, +5 -5 lines changed

@@ -65,7 +65,7 @@ def reset(self):
                 state = self.state[p]
 
                 grad = p.grad
-                g_2 = grad ** 2
+                g_2 = grad ** 2  # fmt: skip
 
                 state['step'] = 0
                 state['moments'] = grad.div(g_2.sqrt() + group['eps']) + group['weight_decay'] * p
@@ -90,6 +90,10 @@ def step(self, closure: CLOSURE = None) -> LOSS:
             bias_correction1 = 1.0 - beta1 ** group['step']
             bias_correction2_sq = math.sqrt(1.0 - beta2 ** group['step'])
 
+            step_size: float = group['lr'] * bias_correction2_sq
+            if not self.adamd_debias_term:
+                step_size /= bias_correction1
+
             for p in group['params']:
                 if p.grad is None:
                     continue
@@ -120,10 +124,6 @@ def step(self, closure: CLOSURE = None) -> LOSS:
 
                 moments.mul_(beta1).add_(grad)
 
-                step_size: float = group['lr'] * bias_correction2_sq
-                if not self.adamd_debias_term:
-                    step_size /= bias_correction1
-
                 p.add_(moments, alpha=-step_size)
 
         return loss
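
The functional change hoists the `step_size` computation out of the per-parameter loop: `group['lr']`, `bias_correction1`, `bias_correction2_sq`, and `self.adamd_debias_term` are all group-level quantities, so the value is identical for every parameter in the group and only needs to be computed once per group. The new `# fmt: skip` comment presumably keeps the formatter from touching the spacing around `**`. Below is a minimal, self-contained sketch of the same pattern in a toy momentum-style optimizer; the class name `HoistedStepSize` and its hyperparameters are illustrative and are not part of this repository.

```python
import math

import torch
from torch.optim import Optimizer


class HoistedStepSize(Optimizer):
    """Toy momentum optimizer showing the group-level step_size hoist (illustrative only)."""

    def __init__(self, params, lr: float = 1e-3, beta1: float = 0.9, beta2: float = 0.999,
                 adamd_debias_term: bool = False):
        self.adamd_debias_term = adamd_debias_term
        defaults = {'lr': lr, 'beta1': beta1, 'beta2': beta2, 'step': 0}
        super().__init__(params, defaults)

    @torch.no_grad()
    def step(self, closure=None):
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            group['step'] += 1
            beta1, beta2 = group['beta1'], group['beta2']

            bias_correction1 = 1.0 - beta1 ** group['step']
            bias_correction2_sq = math.sqrt(1.0 - beta2 ** group['step'])

            # step_size depends only on group-level state, so compute it once
            # per group instead of recomputing it for every parameter.
            step_size: float = group['lr'] * bias_correction2_sq
            if not self.adamd_debias_term:
                step_size /= bias_correction1

            for p in group['params']:
                if p.grad is None:
                    continue

                state = self.state[p]
                if 'moments' not in state:
                    state['moments'] = torch.zeros_like(p)

                moments = state['moments']
                moments.mul_(beta1).add_(p.grad)

                p.add_(moments, alpha=-step_size)

        return loss
```

It plugs in like any `torch.optim` optimizer, e.g. `opt = HoistedStepSize(model.parameters(), lr=1e-2)` followed by the usual `loss.backward(); opt.step()`.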