File tree Expand file tree Collapse file tree 1 file changed +5
-5
lines changed
pytorch_optimizer/optimizer Expand file tree Collapse file tree 1 file changed +5
-5
lines changed Original file line number Diff line number Diff line change @@ -99,13 +99,13 @@ def step(self, closure: CLOSURE = None) -> LOSS:
9999 momentum , fim = state ['momentum' ], state ['fim' ]
100100 fim .mul_ (curr_beta2 ).addcmul_ (grad , grad , value = 1.0 - curr_beta2 )
101101
102- rms_grad = torch .pow (grad , 2 ).mean ().sqrt_ ()
102+ rms_grad = grad .pow (2 ).mean ().sqrt_ ()
103103 curr_eps = min (rms_grad , 1 ) * group ['eps' ]
104104
105- fim_base = torch .pow (fim , group ['p' ]).add_ (curr_eps )
106- grad_nat = torch . div ( grad , fim_base )
105+ fim_base = fim .pow (group ['p' ]).add_ (curr_eps )
106+ grad_nat = grad / fim_base
107107
108- rms = torch .pow (grad_nat , 2 ).mean ().sqrt_ ()
108+ rms = grad_nat .pow (2 ).mean ().sqrt_ ()
109109 divisor = max (1 , rms ) / group ['clip' ]
110110 grad_nat .div_ (divisor )
111111
@@ -119,6 +119,6 @@ def step(self, closure: CLOSURE = None) -> LOSS:
119119
120120 grad_weights .mul_ (group ['weight_decay' ]).add_ (momentum )
121121
122- p .add_ (- grad_weights , alpha = group ['lr' ])
122+ p .add_ (grad_weights , alpha = - group ['lr' ])
123123
124124 return loss
You can’t perform that action at this time.
0 commit comments