Commit b154e4e

update reinforce synapse and testing

1 parent 242161e commit b154e4e

File tree

2 files changed (+51 −10 lines)

ngclearn/components/synapses/modulated/REINFORCESynapse.py

Lines changed: 12 additions & 9 deletions
@@ -15,7 +15,7 @@ class REINFORCESynapse(DenseSynapse):
 
     # Define Functions
     def __init__(
-        self, name, shape, eta=1e-4, weight_init=None, resist_scale=1., act_fx=None,
+        self, name, shape, eta=1e-4, decay=0.99, weight_init=None, resist_scale=1., act_fx=None,
         p_conn=1., w_bound=1., batch_size=1, **kwargs
     ):
         # This is because we have weights mu and weight log sigma
@@ -37,7 +37,8 @@ def __init__(
         self.rewards = Compartment(jnp.zeros((batch_size,))) # the normalized reward (r - r_hat), input compartment
         self.act_fx, self.dact_fx = create_function(act_fx if act_fx is not None else "identity")
         # self.seed = Component(seed)
-
+        self.accumulated_gradients = Compartment(jnp.zeros((input_dim, output_dim * 2)))
+        self.decay = decay
 
     @staticmethod
     def _compute_update(dt, inputs, rewards, act_fx, weights):
@@ -72,9 +73,9 @@ def _compute_update(dt, inputs, rewards, act_fx, weights):
         # Finally, return metrics if needed
         return dW, objective, outputs
 
-    @transition(output_compartments=["weights", "dWeights", "objective", "outputs"])
+    @transition(output_compartments=["weights", "dWeights", "objective", "outputs", "accumulated_gradients"])
     @staticmethod
-    def evolve(dt, w_bound, inputs, rewards, act_fx, weights, eta):
+    def evolve(dt, w_bound, inputs, rewards, act_fx, weights, eta, decay, accumulated_gradients):
         dWeights, objective, outputs = REINFORCESynapse._compute_update(
             dt, inputs, rewards, act_fx, weights
         )
@@ -83,9 +84,10 @@ def evolve(dt, w_bound, inputs, rewards, act_fx, weights, eta):
         ## enforce non-negativity
         eps = 0.01 # 0.001
         weights = jnp.clip(weights, eps, w_bound - eps) # jnp.abs(w_bound))
-        return weights, dWeights, objective, outputs
+        accumulated_gradients = accumulated_gradients * decay + dWeights
+        return weights, dWeights, objective, outputs, accumulated_gradients
 
-    @transition(output_compartments=["inputs", "outputs", "objective", "rewards", "dWeights"])
+    @transition(output_compartments=["inputs", "outputs", "objective", "rewards", "dWeights", "accumulated_gradients"])
     @staticmethod
     def reset(batch_size, shape):
         preVals = jnp.zeros((batch_size, shape[0]))
@@ -95,7 +97,8 @@ def reset(batch_size, shape):
         objective = jnp.zeros(())
         rewards = jnp.zeros((batch_size,))
         dWeights = jnp.zeros(shape)
-        return inputs, outputs, objective, rewards, dWeights
+        accumulated_gradients = jnp.zeros((shape[0], shape[1] * 2))
+        return inputs, outputs, objective, rewards, dWeights, accumulated_gradients
 
     @classmethod
     def help(cls): ## component help function
@@ -110,8 +113,8 @@ def help(cls): ## component help function
         }
         info = {cls.__name__: properties,
                 "compartments": compartment_props,
-                "dynamics": "outputs = [(W * Rscale) * inputs] ;"
-                            "dW_{ij}/dt = A_plus * (z_j - x_tar) * s_i - A_minus * s_j * z_i",
+                # "dynamics": "outputs = [(W * Rscale) * inputs] ;"
+                #             "dW_{ij}/dt = A_plus * (z_j - x_tar) * s_i - A_minus * s_j * z_i",
                 "hyperparameters": hyperparams}
         return info
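The heart of the change to evolve() is the last new line in that hunk: accumulated_gradients holds an exponentially decayed running sum of the REINFORCE updates, scaled by the new decay hyperparameter (default 0.99), and the "* 2" in its shape covers the two stacked weight halves (mean and log-sigma). A minimal standalone sketch of that recurrence; the sizes here are illustrative assumptions, not values from the diff:

import jax.numpy as jnp

decay = 0.99                                    # matches the new default hyperparameter
input_dim, output_dim = 3, 4                    # illustrative sizes (assumptions)
trace = jnp.zeros((input_dim, output_dim * 2))  # mu half and log-sigma half, side by side

for step in range(5):
    dWeights = jnp.full((input_dim, output_dim * 2), 0.1)  # stand-in for one REINFORCE update
    # same recurrence as the new line in evolve(): decay the old trace, add the new gradient
    trace = trace * decay + dWeights

Since the trace is a geometric sum, its steady-state magnitude under a constant update is roughly dWeights / (1 - decay), i.e. about 100x a single update at the default decay of 0.99.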

tests/components/synapses/modulated/test_REINFORCESynapse.py

Lines changed: 39 additions & 1 deletion
@@ -1,5 +1,6 @@
 # %%
 
+import jax
 from jax import numpy as jnp, random, jit
 from ngcsimlib.context import Context
 import numpy as np
@@ -13,6 +14,9 @@
 from ngcsimlib.compartment import Compartment
 from ngcsimlib.context import Context
 
+import jax
+import jax.numpy as jnp
+
 def test_REINFORCESynapse1():
     name = "reinforce_ctx"
     ## create seeding keys
@@ -40,16 +44,50 @@ def clamp_inputs(x):
     def clamp_rewards(x):
         a.rewards.set(x)
 
+    @Context.dynamicCommand
+    def clamp_weights(x):
+        a.weights.set(x)
+
     # a.weights.set(jnp.ones((1, 1)) * 0.1)
 
     ## check pre-synaptic STDP only
     # truth = jnp.array([[1.25]])
     ctx.reset()
+    clamp_weights(jnp.ones((1, 2)))
     clamp_rewards(jnp.ones((1, 1)))
     clamp_inputs(jnp.ones((1, 1)))
     ctx.adapt(t=1., dt=dt)
     # assert_array_equal(a.dWeights.value, truth)
     print(a.dWeights.value)
 
-#test_REINFORCESynapse1()
 
+    # JAX Grad output
+    _act = jax.nn.tanh
+    def fn(params: dict, inputs: jax.Array, outputs: jax.Array, seed):
+        W_mu, W_logstd = params
+        mean = _act(inputs) @ W_mu
+        logstd = _act(inputs) @ W_logstd
+        std = jnp.exp(logstd.clip(-10.0, 2.0))
+        sample = jax.random.normal(seed, mean.shape) * std + mean
+        # logp = gaussian_logpdf(sample, mean, std).sum(-1)
+        logp = jax.scipy.stats.norm.logpdf(sample, mean, std).sum(-1)
+        return (-logp * outputs).mean() * 1e-2
+    grad_fn = jax.value_and_grad(fn)
+
+    weights_mu = jnp.ones((1, 1))
+    weights_logstd = jnp.ones((1, 1))
+    inputs = jnp.ones((1, 1))
+    outputs = jnp.ones((1, 1))
+    objective, grads = grad_fn(
+        (weights_mu, weights_logstd),
+        inputs,
+        outputs,
+        jax.random.key(42)
+    )
+    np.testing.assert_allclose(
+        a.dWeights.value[0],
+        jnp.concatenate([grads[0], grads[1]], axis=-1),
+        atol=1e-2
+    ) # NOTE: gradient is not exact due to different gradient computation, we need to inspect more closely
+
+# test_REINFORCESynapse1()
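The added test block checks the component's dWeights against plain JAX autodiff on the same Gaussian-policy REINFORCE loss, with a loose atol=1e-2 because, as the in-test NOTE says, the two paths do not compute the gradient identically. The sketch below (standalone, not ngclearn code; all names, shapes, and values are illustrative assumptions) shows why such a check is sound in principle: when the sample is drawn once and held fixed, the closed-form score-function gradients of the Gaussian log-density match jax.value_and_grad exactly.

import jax
import jax.numpy as jnp

key = jax.random.key(0)
x = jnp.ones((1, 3))                        # batch of inputs (assumed shape)
W_mu = jnp.full((3, 2), 0.1)                # mean head
W_ls = jnp.zeros((3, 2))                    # log-std head
r = jnp.ones((1,))                          # reward signal

h = jax.nn.tanh(x)
mean, logstd = h @ W_mu, h @ W_ls
std = jnp.exp(logstd)
sample = jax.random.normal(key, mean.shape) * std + mean  # drawn once, held fixed below

def loss(params):
    W_mu, W_ls = params
    mean, logstd = h @ W_mu, h @ W_ls
    std = jnp.exp(logstd)
    # sample is a captured constant here, so autodiff yields the score-function gradient
    logp = jax.scipy.stats.norm.logpdf(sample, mean, std).sum(-1)
    return (-logp * r).mean()

# closed-form gradients of the Gaussian log-density
d_mean = (sample - mean) / std**2               # d logp / d mean
d_logstd = ((sample - mean) / std) ** 2 - 1.0   # d logp / d log-std
B = x.shape[0]
g_mu = -h.T @ (d_mean * r[:, None]) / B
g_ls = -h.T @ (d_logstd * r[:, None]) / B

_, (auto_mu, auto_ls) = jax.value_and_grad(loss)((W_mu, W_ls))
assert jnp.allclose(g_mu, auto_mu, atol=1e-5)
assert jnp.allclose(g_ls, auto_ls, atol=1e-5)

In the committed fn, by contrast, the sample is reparameterized (jax.random.normal(seed, ...) * std + mean) inside the differentiated function, so autodiff also flows through the sample itself rather than treating it as a fixed draw; that extra pathway is presumably part of the mismatch the NOTE flags.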
