Skip to content

Commit 7dbaccd

Browse files
committed
push reinforce synapse
1 parent 1b45362 commit 7dbaccd

File tree

2 files changed

+184
-0
lines changed

2 files changed

+184
-0
lines changed
Lines changed: 130 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,130 @@
1+
from jax import random, numpy as jnp, jit
2+
from ngcsimlib.compilers.process import transition
3+
from ngcsimlib.component import Component
4+
from ngcsimlib.compartment import Compartment
5+
import jax
6+
import jax.numpy as jnp
7+
import numpy as np
8+
9+
from ngclearn.components.synapses import DenseSynapse
10+
from ngclearn.utils import tensorstats
11+
from ngclearn.utils.model_utils import create_function
12+
13+
14+
class REINFORCESynapse(DenseSynapse):
    """
    Dense synapse that parameterizes a diagonal-Gaussian policy and adapts its
    weights with the REINFORCE (score-function) policy-gradient rule.

    The weight matrix handed to the parent has shape
    (input_dim, 2 * output_dim): ``jnp.split`` treats the first ``output_dim``
    columns as the map to the Gaussian mean and the remaining columns as the
    map to the Gaussian log standard deviation.

    | --- Compartments created by this class: ---
    | dWeights - most recent weight adjustment, same shape as the weights
    | objective - scalar surrogate loss of the most recent update
    | outputs - most recently sampled action, (batch_size, output_dim)
    | rewards - baseline-subtracted reward (r - r_hat), (batch_size,)
    | --- Compartments read but not created here: ---
    | inputs, weights - presumably provided by DenseSynapse; confirm there

    Args:
        name: component name, forwarded to DenseSynapse

        shape: tuple ``(input_dim, output_dim)`` of the policy; the underlying
            weight matrix is created with ``2 * output_dim`` columns

        eta: step size applied to ``dWeights`` in ``evolve``

        weight_init: forwarded unchanged to DenseSynapse

        resist_scale: forwarded unchanged to DenseSynapse and stored as
            ``Rscale``

        act_fx: name of the function applied to ``inputs`` before the linear
            maps (resolved with ``create_function``); ``None`` selects
            "identity"

        p_conn: forwarded unchanged to DenseSynapse

        w_bound: after every update the weights are clipped to
            ``[0.01, w_bound - 0.01]``

        batch_size: number of rows in the ``outputs``/``rewards`` compartments
    """

    def __init__(
            self, name, shape, eta=1e-4, weight_init=None, resist_scale=1., act_fx=None,
            p_conn=1., w_bound=1., batch_size=1, **kwargs
    ):
        # The parent gets twice as many output columns because one weight
        # matrix stores both the mean map and the log-std map.
        input_dim, output_dim = shape
        # NOTE(review): the 4th positional argument (None) is presumably the
        # parent's bias initializer -- confirm against DenseSynapse.
        super().__init__(name, (input_dim, output_dim * 2), weight_init, None, resist_scale,
                         p_conn, batch_size=batch_size, **kwargs)

        ## Synaptic hyper-parameters
        self.shape = shape  ## policy shape (input_dim, output_dim), NOT the weight shape
        self.Rscale = resist_scale  ## post-transformation scale factor
        self.w_bound = w_bound  ## upper limit used when clipping the weights
        self.eta = eta  ## learning rate

        ## Compartment setup
        self.dWeights = Compartment(self.weights.value * 0)
        self.objective = Compartment(jnp.zeros(()))
        self.outputs = Compartment(jnp.zeros((batch_size, output_dim)))
        ## the normalized reward (r - r_hat), input compartment
        self.rewards = Compartment(jnp.zeros((batch_size,)))
        self.act_fx, self.dact_fx = create_function(act_fx if act_fx is not None else "identity")

    @staticmethod
    def _compute_update(dt, inputs, rewards, act_fx, weights):
        """
        Sample an action from the Gaussian policy and compute the REINFORCE
        weight adjustment for it.

        Args:
            dt: integration time constant (unused here)

            inputs: pre-synaptic values, (batch_size, input_dim)

            rewards: baseline-subtracted rewards; flattened to (batch_size,)
                so that both (batch_size,) and (batch_size, 1) are accepted

            act_fx: callable applied to ``inputs`` before the linear maps

            weights: (input_dim, 2 * output_dim) matrix holding [W_mu | W_logstd]

        Returns:
            (dW, objective, outputs) where ``dW`` has the shape of ``weights``
            and is the direction to ADD to the weights (the negated gradient
            of ``objective``), ``objective`` is the scalar surrogate loss
            ``mean(-log_prob * rewards) * 1e-2`` and ``outputs`` is the
            sampled action of shape (batch_size, output_dim)
        """
        # (input_dim, output_dim * 2) => (input_dim, output_dim), (input_dim, output_dim)
        W_mu, W_logstd = jnp.split(weights, 2, axis=-1)
        # Flatten so a (batch, 1) reward cannot broadcast into a 3-D update.
        rewards = jnp.reshape(rewards, (-1,))
        batch_size = inputs.shape[0]

        # Forward pass
        activation = act_fx(inputs)
        mean = activation @ W_mu
        # Clip once and reuse the clipped value everywhere, so the density
        # evaluated below is the density of the distribution sampled from.
        logstd = jnp.clip(activation @ W_logstd, -10.0, 2.0)
        std = jnp.exp(logstd)
        # NOTE(review): np.random is evaluated when jax.jit traces this
        # function, so under jit the same noise is baked into every call.
        # A proper fix threads a jax.random PRNG key through `evolve`, which
        # needs a key/seed compartment -- confirm what the parent provides.
        epsilon = jnp.asarray(np.random.normal(0, 1, mean.shape))
        sample = epsilon * std + mean
        outputs = sample  # the actual action that we take

        # Log probability density of the diagonal Gaussian
        z = (sample - mean) / std
        log_prob = (-0.5 * jnp.log(2 * jnp.pi) - logstd - 0.5 * z ** 2).sum(-1)
        # Surrogate loss (negative REINFORCE objective)
        objective = (-log_prob * rewards).mean() * 1e-2

        # Backward pass (manual gradients of `objective`)
        # dlog_prob/dmu = (sample - mu) / sigma^2
        dlog_prob_dmean = z / std
        # dlog_prob/dlog(sigma) = ((sample - mu) / sigma)^2 - 1
        # (the clip on log(sigma) is treated as pass-through here)
        dlog_prob_dlogstd = z ** 2 - 1.0
        # dL/dlog_prob per sample; the 1/batch_size matches the .mean() above
        dL_dlog_prob = -rewards[:, None] * (1e-2 / batch_size)
        # Chain rule through the linear maps: dL/dW = activation^T @ dL/d(pre)
        dL_dWmu = activation.T @ (dL_dlog_prob * dlog_prob_dmean)
        dL_dWlstd = activation.T @ (dL_dlog_prob * dlog_prob_dlogstd)
        # `evolve` ADDS dW to the weights, so hand back the negated loss
        # gradient; adding dL/dW itself would make rewarded actions less likely.
        dW = jnp.concatenate([-dL_dWmu, -dL_dWlstd], axis=-1)
        return dW, objective, outputs

    @transition(output_compartments=["weights", "dWeights", "objective", "outputs"])
    @staticmethod
    def evolve(dt, w_bound, inputs, rewards, act_fx, weights, eta):
        """
        Sample an action, apply one REINFORCE step to the weights and clip
        them to ``[0.01, w_bound - 0.01]``.

        Returns:
            (weights, dWeights, objective, outputs) in the order of the
            transition's output compartments
        """
        dWeights, objective, outputs = REINFORCESynapse._compute_update(
            dt, inputs, rewards, act_fx, weights
        )
        ## dWeights already points along the direction that increases expected reward
        weights = weights + dWeights * eta
        ## enforce non-negativity / soft upper bound
        # NOTE(review): this also keeps W_mu and W_logstd strictly positive,
        # which restricts the policy -- confirm that this is intended.
        eps = 0.01
        weights = jnp.clip(weights, eps, w_bound - eps)
        return weights, dWeights, objective, outputs

    @transition(output_compartments=["inputs", "outputs", "objective", "rewards", "dWeights"])
    @staticmethod
    def reset(batch_size, shape):
        """
        Zero the input, output, objective, reward and weight-adjustment
        compartments.

        Args:
            batch_size: number of rows of the batched compartments

            shape: policy shape ``(input_dim, output_dim)``
        """
        inputs = jnp.zeros((batch_size, shape[0]))
        outputs = jnp.zeros((batch_size, shape[1]))
        objective = jnp.zeros(())
        rewards = jnp.zeros((batch_size,))
        # The weights hold [W_mu | W_logstd], so the adjustment has twice as
        # many columns as the policy has outputs.
        dWeights = jnp.zeros((shape[0], shape[1] * 2))
        return inputs, outputs, objective, rewards, dWeights

    @classmethod
    def help(cls):  ## component help function
        """Return a dictionary describing this component."""
        properties = {
            "synapse_type": "REINFORCESynapse - samples its outputs from a "
                            "Gaussian policy and adapts its weights with the "
                            "REINFORCE policy-gradient rule"
        }
        compartment_props = {
            "inputs":
                {"inputs": "Takes in external input signal values",
                 "rewards": "Baseline-subtracted reward (r - r_hat) per sample"},
            "states":
                {"weights": "Synaptic parameters [W_mu | W_logstd]"},
            "analytics":
                {"dWeights": "Most recent adjustment added to the weights",
                 "objective": "Surrogate loss of the most recent update"},
            "outputs":
                {"outputs": "Action sampled from the Gaussian policy"},
        }
        hyperparams = {
            "shape": "Policy shape (input_dim, output_dim)",
            "eta": "Learning rate",
            "weight_init": "Initialization passed on to the dense synapse",
            "resist_scale": "Scale factor passed on to the dense synapse",
            "act_fx": "Function applied to the inputs before the linear maps",
            "p_conn": "Connection probability passed on to the dense synapse",
            "w_bound": "Upper bound used when clipping the weights",
            "batch_size": "Batch size dimension of this component",
        }
        info = {cls.__name__: properties,
                "compartments": compartment_props,
                "dynamics": "outputs ~ N(mu, sigma^2), mu = f(inputs) * W_mu, "
                            "log(sigma) = clip(f(inputs) * W_logstd) ; "
                            "dW = (r - r_hat) * d log p(outputs) / dW",
                "hyperparameters": hyperparams}
        return info

    def __repr__(self):
        """Summarize every compartment of this component, one per line."""
        comps = [varname for varname in dir(self) if Compartment.is_compartment(getattr(self, varname))]
        # default=0 keeps this usable on a component without compartments
        maxlen = max((len(c) for c in comps), default=0) + 5
        lines = f"[{self.__class__.__name__}] PATH: {self.name}\n"
        for c in comps:
            stats = tensorstats(getattr(self, c).value)
            if stats is not None:
                line = ", ".join(f"{k}: {v}" for k, v in stats.items())
            else:
                line = "None"
            lines += f"  {f'({c})'.ljust(maxlen)}{line}\n"
        return lines
Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
# %%
2+
3+
from jax import numpy as jnp, random, jit
4+
from ngcsimlib.context import Context
5+
import numpy as np
6+
np.random.seed(42)
7+
from ngclearn.components.synapses.hebbian.REINFORCESynapse import REINFORCESynapse
8+
from ngcsimlib.compilers import compile_command, wrap_command
9+
from numpy.testing import assert_array_equal
10+
11+
from ngcsimlib.compilers.process import Process, transition
12+
from ngcsimlib.component import Component
13+
from ngcsimlib.compartment import Compartment
14+
from ngcsimlib.context import Context
15+
16+
def test_REINFORCESynapse1():
    """Smoke-test a single REINFORCE adaptation step on a 1x1 synapse.

    The sampled action is random, so only the shape and finiteness of the
    resulting weight adjustment are checked, not exact values.
    """
    name = "reinforce_ctx"
    ## create seeding keys
    dkey = random.PRNGKey(1234)
    dkey, *subkeys = random.split(dkey, 6)
    dt = 1.  # ms
    # ---- build a single REINFORCE synapse system ----
    with Context(name) as ctx:
        a = REINFORCESynapse(
            name="a", shape=(1, 1), act_fx="tanh", key=subkeys[0]
        )

        evolve_process = (Process() >> a.evolve)
        ctx.wrap_and_add_command(jit(evolve_process.pure), name="adapt")

        reset_process = (Process() >> a.reset)
        ctx.wrap_and_add_command(jit(reset_process.pure), name="reset")

        @Context.dynamicCommand
        def clamp_inputs(x):
            a.inputs.set(x)

        @Context.dynamicCommand
        def clamp_rewards(x):
            a.rewards.set(x)

    ## run one adaptation step with unit input and unit reward
    ctx.reset()
    # the rewards compartment is declared with shape (batch_size,), not (batch_size, 1)
    clamp_rewards(jnp.ones((1,)))
    clamp_inputs(jnp.ones((1, 1)))
    ctx.adapt(t=1., dt=dt)
    dW = a.dWeights.value
    print(dW)
    # one column for the mean weights and one for the log-std weights
    assert dW.shape == (1, 2)
    assert bool(jnp.all(jnp.isfinite(dW)))
52+
53+
# test_REINFORCESynapse1()
54+

0 commit comments

Comments
 (0)