Commit 51c2650

Author: Alexander Ororbia (committed)
cleaned up utils.optim and wrote compliant NAG optim
1 parent 9517689 commit 51c2650

5 files changed: +104 -16 lines changed

docs/source/ngclearn.utils.optim.rst

Lines changed: 8 additions & 0 deletions

@@ -12,6 +12,14 @@ ngclearn.utils.optim.adam module
    :undoc-members:
    :show-inheritance:

+ngclearn.utils.optim.nag module
+-------------------------------
+
+.. automodule:: ngclearn.utils.optim.nag
+   :members:
+   :undoc-members:
+   :show-inheritance:
+
 ngclearn.utils.optim.optim\_utils module
 ----------------------------------------

ngclearn/utils/optim/adam.py

Lines changed: 6 additions & 13 deletions

@@ -1,16 +1,11 @@
 # %%

-# from ngcsimlib.component import Component
-# from ngcsimlib.compartment import Compartment
-# from ngcsimlib.resolver import resolver
-
 import numpy as np
 from jax import jit, numpy as jnp, random, nn, lax
 from functools import partial
-import time


-def step_update(param, update, g1, g2, lr, beta1, beta2, time, eps):
+def step_update(param, update, g1, g2, lr, beta1, beta2, time_step, eps):
     """
     Runs one step of Adam over a set of parameters given updates.
     The dynamics for any set of parameters is as follows:

@@ -39,17 +34,17 @@ def step_update(param, update, g1, g2, lr, beta1, beta2, time, eps):

         beta2: 2nd moment control factor

-        time: current time t or iteration step/call to this Adam update
+        time_step: current time t or iteration step/call to this Adam update

         eps: numerical stability coefficient (for calculating final update)

     Returns:
-        adjusted parameter tensor (same shape as "param")
+        adjusted parameter tensor (same shape as "param"), adjusted g1, adjusted g2
     """
     _g1 = beta1 * g1 + (1. - beta1) * update
     _g2 = beta2 * g2 + (1. - beta2) * jnp.square(update)
-    g1_unb = _g1 / (1. - jnp.power(beta1, time))
-    g2_unb = _g2 / (1. - jnp.power(beta2, time))
+    g1_unb = _g1 / (1. - jnp.power(beta1, time_step))
+    g2_unb = _g2 / (1. - jnp.power(beta2, time_step))
     _param = param - lr * g1_unb/(jnp.sqrt(g2_unb) + eps)
     return _param, _g1, _g2

@@ -83,9 +78,7 @@ def adam_step(opt_params, theta, updates, eta=0.001, beta1=0.9, beta2=0.999, eps
     new_g1 = []
     new_g2 = []
     for i in range(len(theta)):
-        px_i, g1_i, g2_i = step_update(theta[i], updates[i], g1[i],
-                                       g2[i], eta, beta1,
-                                       beta2, time_step, eps)
+        px_i, g1_i, g2_i = step_update(theta[i], updates[i], g1[i], g2[i], eta, beta1, beta2, time_step, eps)
         new_theta.append(px_i)
         new_g1.append(g1_i)
         new_g2.append(g2_i)
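
For reference, the renamed time_step argument threads directly through the updated step_update signature. The following is a minimal sketch (not part of this commit) that exercises one Adam step on toy tensors; it assumes ngclearn is installed so that ngclearn.utils.optim.adam.step_update matches the diff above.

# Sketch: one Adam step with the updated step_update signature (toy values).
from jax import numpy as jnp
from ngclearn.utils.optim.adam import step_update

param = jnp.asarray([1.0, -2.0])   # parameter tensor to adjust
update = jnp.asarray([0.5, 0.5])   # update/gradient produced elsewhere
g1 = jnp.zeros_like(param)         # 1st-moment estimate
g2 = jnp.zeros_like(param)         # 2nd-moment estimate

# time_step starts at 1 so the bias-correction denominators are non-zero
new_param, new_g1, new_g2 = step_update(param, update, g1, g2, lr=0.001,
                                        beta1=0.9, beta2=0.999, time_step=1,
                                        eps=1e-8)
print(new_param, new_g1, new_g2)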

ngclearn/utils/optim/nag.py

Lines changed: 84 additions & 0 deletions

@@ -0,0 +1,84 @@
+# %%
+
+import numpy as np
+from jax import jit, numpy as jnp, random, nn, lax
+from functools import partial
+import time
+
+
+def step_update(param, update, phi_old, lr, mu, time_step):
+    """
+    Runs one step of Nesterov's accelerated gradient (NAG) over a set of parameters given updates.
+    The dynamics for any set of parameters is as follows:
+
+    | phi = param - update * lr
+    | param = phi + (phi - phi_previous) * mu, where mu = 0 if t <= 1 (first iteration)
+
+    Args:
+        param: parameter tensor to change/adjust
+
+        update: update tensor to be applied to parameter tensor (must be same
+            shape as "param")
+
+        phi_old: previous friction/momentum parameter
+
+        lr: global step size value to be applied to updates to parameters
+
+        mu: friction/momentum control factor
+
+        time_step: current time t or iteration step/call to this NAG update
+
+    Returns:
+        adjusted parameter tensor (same shape as "param"), adjusted momentum/friction variable
+    """
+    phi = param - update * lr ## do a phantom gradient adjustment step
+    _param = phi + (phi - phi_old) * (mu * (time_step > 1.)) ## NAG-step
+    _phi_old = phi
+    return _param, _phi_old
+
+@jit
+def nag_step(opt_params, theta, updates, eta=0.01, mu=0.9): ## apply adjustment to theta
+    """
+    Implements Nesterov's accelerated gradient (NAG) algorithm as a decoupled update
+    rule given adjustments produced by a credit assignment algorithm/process.
+
+    Args:
+        opt_params: (ArrayLike) parameters of the optimization algorithm
+
+        theta: (ArrayLike) the weights of the neural network
+
+        updates: (ArrayLike) the updates for the neural network
+
+        eta: (float, optional) step size coefficient for the NAG update (Default: 0.01)
+
+        mu: (float, optional) friction/momentum control factor (Default: 0.9)
+
+    Returns:
+        ArrayLike: opt_params (the new optimizer state), ArrayLike: theta (the updated weights)
+    """
+    phi, time_step = opt_params
+    time_step = time_step + 1
+    new_theta = []
+    new_phi = []
+    for i in range(len(theta)):
+        px_i, phi_i = step_update(theta[i], updates[i], phi[i], eta, mu, time_step)
+        new_theta.append(px_i)
+        new_phi.append(phi_i)
+    return (new_phi, time_step), new_theta
+
+@jit
+def nag_init(theta):
+    time_step = jnp.asarray(0.0)
+    phi = [jnp.zeros(theta[i].shape) for i in range(len(theta))]
+    return phi, time_step
+
+if __name__ == '__main__':
+    weights = [jnp.asarray([3.0, 3.0]), jnp.asarray([3.0, 3.0])]
+    updates = [jnp.asarray([3.0, 3.0]), jnp.asarray([3.0, 3.0])]
+    opt_params = nag_init(weights)
+    opt_params, theta = nag_step(opt_params, weights, updates)
+    print(f"opt_params: {opt_params}, theta: {theta}")
+    weights = theta
+    print("##################")
+    opt_params, theta = nag_step(opt_params, weights, updates)
+    print(f"opt_params: {opt_params}, theta: {theta}")
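
Because the momentum term is gated by (time_step > 1.), the first call to nag_step reduces to a plain gradient step and the Nesterov extrapolation only kicks in from the second call onward. A small sketch (not part of this commit) that mirrors the __main__ block above and works out the expected values by hand, assuming ngclearn exposes the new module as added here:

# Sketch: check the time-step gating in nag_step with the defaults eta=0.01, mu=0.9.
from jax import numpy as jnp
from ngclearn.utils.optim.nag import nag_init, nag_step

theta = [jnp.asarray([3.0, 3.0])]
grads = [jnp.asarray([3.0, 3.0])]
opt_state = nag_init(theta)

# Call 1: time_step becomes 1, the (time_step > 1.) mask is 0, so
# phi = 3.0 - 0.01 * 3.0 = 2.97 and theta = phi = 2.97
opt_state, theta = nag_step(opt_state, theta, grads)
print(theta)  # ~[2.97, 2.97]

# Call 2: phi = 2.97 - 0.03 = 2.94; the extrapolation adds
# (2.94 - 2.97) * 0.9 = -0.027, so theta = 2.913
opt_state, theta = nag_step(opt_state, theta, grads)
print(theta)  # ~[2.913, 2.913]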
Lines changed: 4 additions & 1 deletion

@@ -1,17 +1,20 @@
 import functools
 from .sgd import sgd_step, sgd_init
+from .nag import nag_step, nag_init
 from .adam import adam_step, adam_init

 def get_opt_init_fn(opt='adam'):
     return {
         'adam': adam_init,
+        'nag': nag_init,
         'sgd': sgd_init
     }[opt]


 def get_opt_step_fn(opt='adam', **kwargs):
-    # **kwargs here is the hyper parameters you want to pass in the optimization function
+    ## **kwargs here are the hyper-parameters you want to pass into the optimization function
     return {
         'adam': functools.partial(adam_step, **kwargs),
+        'nag': functools.partial(nag_step, **kwargs),
         'sgd': functools.partial(sgd_step, **kwargs),
     }[opt]
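
With 'nag' registered next to 'adam' and 'sgd', the optimizer can now be selected by name and configured through the forwarded keyword arguments. A short usage sketch (not part of this commit); the exact module that exposes these helpers is not named in this diff, so the import below from ngclearn.utils.optim is an assumption:

# Sketch: select and drive the new NAG optimizer by name.
# Assumes get_opt_init_fn/get_opt_step_fn are importable from ngclearn.utils.optim.
from jax import numpy as jnp
from ngclearn.utils.optim import get_opt_init_fn, get_opt_step_fn

theta = [jnp.asarray([3.0, 3.0])]
grads = [jnp.asarray([3.0, 3.0])]

init_fn = get_opt_init_fn('nag')
step_fn = get_opt_step_fn('nag', eta=0.01, mu=0.9)  # kwargs are bound via functools.partial

opt_state = init_fn(theta)
opt_state, theta = step_fn(opt_state, theta, grads)  # one NAG update of theta
print(theta)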

ngclearn/utils/optim/sgd.py

Lines changed: 2 additions & 2 deletions

@@ -15,7 +15,8 @@ def step_update(param, update, lr):

 @jit
 def sgd_step(opt_params, theta, updates, eta=0.001): ## apply adjustment to theta
-    """Return a params update
+    """
+    Returns updated parameters in accordance with a stochastic gradient descent (SGD) recipe

     Args:
         opt_params: (ArrayLike) parameters of the optimization algorithm

@@ -42,7 +43,6 @@ def sgd_step(opt_params, theta, updates, eta=0.001): ## apply adjustment to theta
 def sgd_init(theta):
     return jnp.asarray(0.0)

-
 if __name__ == '__main__':
     opt_params, theta = sgd_step((2.0), [1.0, 1.0], [3.0, 4.0], 3e-2)
     print(f"opt_params: {opt_params}, theta: {theta}")