File tree Expand file tree Collapse file tree 1 file changed +1
-6
lines changed
pytorch_optimizer/optimizer Expand file tree Collapse file tree 1 file changed +1
-6
lines changed Original file line number Diff line number Diff line change 1- import numpy as np
2 1 import torch
3 2
4 3 from pytorch_optimizer.base.exception import NoSparseGradientError
@@ -38,8 +37,6 @@ def __init__(
3837
3938 self .maximize = maximize
4039
41 -         self.sq2: float = np.sqrt(2)
42-
4340 defaults : DEFAULTS = {
4441 'lr' : lr ,
4542 'betas' : betas ,
@@ -88,8 +85,6 @@ def step(self, closure: CLOSURE = None) -> LOSS:
88 85             bias_correction1: float = self.debias(beta1, group['step'])
89 86             bias_correction2: float = self.debias(beta2, group['step'])
9087
91 -             step_size: float = group['lr'] * np.log(np.sqrt(group['step'] + 1) * self.sq2)
92-
9388 for p in group ['params' ]:
9489 if p .grad is None :
9590 continue
@@ -128,6 +123,6 @@ def step(self, closure: CLOSURE = None) -> LOSS:
128123
129 124                 update = (m_tilde + g_tilde) / v_tilde.sqrt().add_(group['eps'])
130125
131 -                 p.add_(update, alpha=-step_size)
126 +                 p.add_(update, alpha=-group['lr'])
132127
133128 return loss
You can’t perform that action at this time.
0 commit comments