Skip to content

Commit 05e0a7d

Browse files
author
Alexander Ororbia
committed
cleaned up utils.optim and wrote compliant NAG optim
1 parent 51c2650 commit 05e0a7d

File tree

3 files changed

+9
-9
lines changed

3 files changed

+9
-9
lines changed

ngclearn/utils/optim/adam.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from functools import partial
66

77

8-
def step_update(param, update, g1, g2, lr, beta1, beta2, time_step, eps):
8+
def step_update(param, update, g1, g2, eta, beta1, beta2, time_step, eps):
99
"""
1010
Runs one step of Adam over a set of parameters given updates.
1111
The dynamics for any set of parameters is as follows:
@@ -28,7 +28,7 @@ def step_update(param, update, g1, g2, lr, beta1, beta2, time_step, eps):
2828
g2: second moment factor/correction factor to use in parameter update
2929
(must be same shape as "update")
3030
31-
lr: global step size value to be applied to updates to parameters
31+
eta: global step size value to be applied to updates to parameters
3232
3333
beta1: 1st moment control factor
3434
@@ -45,7 +45,7 @@ def step_update(param, update, g1, g2, lr, beta1, beta2, time_step, eps):
4545
_g2 = beta2 * g2 + (1. - beta2) * jnp.square(update)
4646
g1_unb = _g1 / (1. - jnp.power(beta1, time_step))
4747
g2_unb = _g2 / (1. - jnp.power(beta2, time_step))
48-
_param = param - lr * g1_unb/(jnp.sqrt(g2_unb) + eps)
48+
_param = param - eta * g1_unb/(jnp.sqrt(g2_unb) + eps)
4949
return _param, _g1, _g2
5050

5151
@jit

ngclearn/utils/optim/nag.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
import time
77

88

9-
def step_update(param, update, phi_old, lr, mu, time_step):
9+
def step_update(param, update, phi_old, eta, mu, time_step):
1010
"""
1111
Runs one step of Nesterov's accelerated gradient (NAG) over a set of parameters given updates.
1212
The dynamics for any set of parameters is as follows:
@@ -22,7 +22,7 @@ def step_update(param, update, phi_old, lr, mu, time_step):
2222
2323
phi_old: previous friction/momentum parameter
2424
25-
lr: global step size value to be applied to updates to parameters
25+
eta: global step size value to be applied to updates to parameters
2626
2727
mu: friction/momentum control factor
2828
@@ -31,7 +31,7 @@ def step_update(param, update, phi_old, lr, mu, time_step):
3131
Returns:
3232
adjusted parameter tensor (same shape as "param"), adjusted momentum/friction variable
3333
"""
34-
phi = param - update * lr ## do a phantom gradient adjustment step
34+
phi = param - update * eta ## do a phantom gradient adjustment step
3535
_param = phi + (phi - phi_old) * (mu * (time_step > 1.)) ## NAG-step
3636
_phi_old = phi
3737
return _param, _phi_old

ngclearn/utils/optim/sgd.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,16 @@
11
from jax import jit, numpy as jnp
22

3-
def step_update(param, update, lr):
3+
def step_update(param, update, eta):
44
"""
55
Runs one step of SGD over a set of parameters given updates.
66
77
Args:
8-
lr: global step size to apply when adjusting parameters
8+
eta: global step size to apply when adjusting parameters
99
1010
Returns:
1111
adjusted parameter tensor (same shape as "param")
1212
"""
13-
_param = param - lr * update
13+
_param = param - update * eta
1414
return _param
1515

1616
@jit

0 commit comments

Comments (0)