Skip to content

Commit e7f482b

Browse files
committed
update working reinforce synapse
1 parent e9d7068 commit e7f482b

File tree

2 files changed

+93
-118
lines changed

2 files changed

+93
-118
lines changed

ngclearn/components/synapses/modulated/REINFORCESynapse.py

Lines changed: 40 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -11,17 +11,18 @@
1111
from ngclearn.utils.model_utils import create_function
1212

1313
def gaussian_logpdf(event, mean, stddev):
14-
scale_sqrd = stddev ** 2
15-
log_normalizer = jnp.log(2 * jnp.pi * scale_sqrd)
16-
quadratic = (jax.lax.stop_gradient(event - 2 * mean) + mean)**2 / scale_sqrd
17-
return - 0.5 * (log_normalizer + quadratic)
14+
# scale_sqrd = stddev ** 2
15+
# log_normalizer = jnp.log(2 * jnp.pi * scale_sqrd)
16+
# quadratic = (jax.lax.stop_gradient(event - 2 * mean) + mean)**2 / scale_sqrd
17+
# return - 0.5 * (log_normalizer + quadratic)
18+
return -0.5 * jnp.log(2 * jnp.pi) - jnp.log(stddev) - 0.5 * ( (jax.lax.stop_gradient(event - 2 * mean) + mean) / stddev )**2
1819

1920
class REINFORCESynapse(DenseSynapse):
2021

2122
# Define Functions
2223
def __init__(
2324
self, name, shape, eta=1e-4, decay=0.99, weight_init=None, resist_scale=1., act_fx=None,
24-
p_conn=1., w_bound=1., batch_size=1, **kwargs
25+
p_conn=1., w_bound=1., batch_size=1, seed=None, **kwargs
2526
):
2627
# This is because we have weights mu and weight log sigma
2728
input_dim, output_dim = shape
@@ -46,48 +47,61 @@ def __init__(
4647
self.decay = decay
4748
self.step_count = Compartment(jnp.zeros(()))
4849
self.learning_mask = Compartment(jnp.zeros(()))
50+
# self.seed = Component(jnp.array(seed) if seed is not None else jnp.array(42, dtype=jnp.int32))
51+
self.seed = Compartment(jax.random.PRNGKey(seed if seed is not None else 42))
4952

5053
@staticmethod
51-
def _compute_update(dt, inputs, rewards, act_fx, weights):
54+
def _compute_update(dt, inputs, rewards, act_fx, weights, seed):
5255
W_mu, W_logstd = jnp.split(weights, 2, axis=-1) # (input_dim, output_dim * 2) => (input_dim, output_dim), (input_dim, output_dim)
5356
# Forward pass
5457
activation = act_fx(inputs)
5558
mean = activation @ W_mu
5659
logstd = activation @ W_logstd
57-
std = jnp.exp(logstd.clip(-10.0, 2.0))
60+
clip_logstd = jnp.clip(logstd, -10.0, 2.0)
61+
std = jnp.exp(clip_logstd)
5862
# Sample using reparameterization trick
59-
epsilon = jnp.asarray(np.random.normal(0, 1, mean.shape))
63+
epsilon = jax.random.normal(seed, mean.shape)
6064
sample = epsilon * std + mean
6165
outputs = sample # the actual action that we take
6266
# Compute log probability density of the Gaussian
63-
log_prob = gaussian_logpdf(sample, mean, std)
64-
log_prob = log_prob.sum(-1)
67+
log_prob = gaussian_logpdf(sample, mean, std).sum(-1)
6568
# Compute objective (negative REINFORCE objective)
6669
objective = (-log_prob * rewards).mean() * 1e-2
70+
6771
# Backward pass
72+
batch_size = inputs.shape[0] # B
73+
dL_dlogp = -rewards[:, None] * 1e-2 / batch_size # (B, 1)
74+
6875
# Compute gradients manually based on the derivation
6976
# dL/dmu = -(r-r_hat) * dlog_prob/dmu = -(r-r_hat) * -(sample-mu)/sigma^2
7077
# -(sample - mean) instead of (sample - mean) because we are doing straight-through gradient in the log_prob function
7178
# therefore, computation including the mean in such function does not contribute to the gradient
7279
dlog_prob_dmean = -(sample - mean) / (std ** 2)
80+
dL_dmean = dL_dlogp * dlog_prob_dmean # (B, A)
81+
dL_dWmu = activation.T @ dL_dmean
82+
7383
# dL/dlog(sigma) = -(r-r_hat) * dlog_prob/dlog(sigma) = -(r-r_hat) * (((sample-mu)/sigma)^2 - 1)
74-
dlog_prob_dlogstd = ((sample - mean) / std) ** 2 - 1.0
75-
# Compute gradients with respect to weights
76-
# Using chain rule: dL/dW_mu = dL/dmu * dmu/dW_mu = dL/dmu * activation^T
77-
# Similarly for W_logstd
78-
# Gradient ascent instead of descent
79-
dL_dWmu = activation.T @ (rewards[:, None] * dlog_prob_dmean) * 1e-2
80-
dL_dWlstd = activation.T @ (rewards[:, None] * dlog_prob_dlogstd) * 1e-2
81-
# Update weights
82-
dW = jnp.concatenate([dL_dWmu, dL_dWlstd], axis=-1)
84+
dlog_prob_dlogstd = (sample - mean)**2 / std**3 - 1.0 / std
85+
dL_dstd = dL_dlogp * dlog_prob_dlogstd
86+
# Apply gradient clipping for logstd
87+
dL_dlogstd = jnp.where(
88+
(logstd <= -10.0) | (logstd >= 2.0),
89+
0.0, # Zero gradient when clipped
90+
dL_dstd * std
91+
)
92+
dL_dWlogstd = activation.T @ dL_dlogstd # (I, B) @ (B, A) = (I, A)
93+
94+
# Update weights, negate the gradient because gradient ascent in ngc-learn
95+
dW = jnp.concatenate([-dL_dWmu, -dL_dWlogstd], axis=-1)
8396
# Finally, return metrics if needed
8497
return dW, objective, outputs
8598

86-
@transition(output_compartments=["weights", "dWeights", "objective", "outputs", "accumulated_gradients", "step_count"])
99+
@transition(output_compartments=["weights", "dWeights", "objective", "outputs", "accumulated_gradients", "step_count", "seed"])
87100
@staticmethod
88-
def evolve(dt, w_bound, inputs, rewards, act_fx, weights, eta, learning_mask, decay, accumulated_gradients, step_count):
101+
def evolve(dt, w_bound, inputs, rewards, act_fx, weights, eta, learning_mask, decay, accumulated_gradients, step_count, seed):
102+
main_seed, sub_seed = jax.random.split(seed)
89103
dWeights, objective, outputs = REINFORCESynapse._compute_update(
90-
dt, inputs, rewards, act_fx, weights
104+
dt, inputs, rewards, act_fx, weights, sub_seed
91105
)
92106
## do a gradient ascent update/shift
93107
weights = (weights + dWeights * eta) * learning_mask + weights * (1.0 - learning_mask) # update the weights only where learning_mask is 1.0
@@ -97,9 +111,9 @@ def evolve(dt, w_bound, inputs, rewards, act_fx, weights, eta, learning_mask, de
97111
step_count += 1
98112
accumulated_gradients = (step_count - 1) / step_count * accumulated_gradients * decay + 1.0 / step_count * dWeights # EMA update of accumulated gradients
99113
step_count = step_count * (1 - learning_mask) # reset the step count to 0 when we have learned
100-
return weights, dWeights, objective, outputs, accumulated_gradients, step_count
114+
return weights, dWeights, objective, outputs, accumulated_gradients, step_count, main_seed
101115

102-
@transition(output_compartments=["inputs", "outputs", "objective", "rewards", "dWeights", "accumulated_gradients", "step_count"])
116+
@transition(output_compartments=["inputs", "outputs", "objective", "rewards", "dWeights", "accumulated_gradients", "step_count", "seed"])
103117
@staticmethod
104118
def reset(batch_size, shape):
105119
preVals = jnp.zeros((batch_size, shape[0]))
@@ -111,7 +125,8 @@ def reset(batch_size, shape):
111125
dWeights = jnp.zeros(shape)
112126
accumulated_gradients = jnp.zeros((shape[0], shape[1] * 2))
113127
step_count = jnp.zeros(())
114-
return inputs, outputs, objective, rewards, dWeights, accumulated_gradients, step_count
128+
seed = jax.random.PRNGKey(42)
129+
return inputs, outputs, objective, rewards, dWeights, accumulated_gradients, step_count, seed
115130

116131
@classmethod
117132
def help(cls): ## component help function

tests/components/synapses/modulated/test_REINFORCESynapse.py

Lines changed: 53 additions & 93 deletions
Original file line numberDiff line numberDiff line change
@@ -25,10 +25,12 @@ def test_REINFORCESynapse1():
2525
dkey, *subkeys = random.split(dkey, 6)
2626
dt = 1. # ms
2727
decay = 0.99
28+
initial_seed = 42
29+
2830
# ---- build a simple Poisson cell system ----
2931
with Context(name) as ctx:
3032
a = REINFORCESynapse(
31-
name="a", shape=(1,1), decay=decay, act_fx="tanh", key=subkeys[0]
33+
name="a", shape=(1,1), decay=decay, act_fx="tanh", key=subkeys[0], seed=initial_seed
3234
)
3335

3436
evolve_process = (Process("evolve_proc") >> a.evolve)
@@ -43,6 +45,7 @@ def clamp_inputs(x):
4345

4446
@Context.dynamicCommand
4547
def clamp_rewards(x):
48+
assert x.ndim == 1, "Rewards must be a 1D array"
4649
a.rewards.set(x)
4750

4851
@Context.dynamicCommand
@@ -51,20 +54,20 @@ def clamp_weights(x):
5154

5255
# Function definition
5356
_act = jax.nn.tanh
54-
def fn(params: dict, inputs: jax.Array, outputs: jax.Array):
57+
def fn(params: dict, inputs: jax.Array, outputs: jax.Array, seed: jax.Array):
5558
W_mu, W_logstd = params
5659
activation = _act(inputs)
5760
mean = activation @ W_mu
5861
logstd = activation @ W_logstd
5962
std = jnp.exp(logstd.clip(-10.0, 2.0))
60-
# sample = jax.random.normal(seed, mean.shape) * std + mean
61-
sample = jnp.asarray(np.random.normal(0, 1, mean.shape)) * std + mean
63+
sample = jax.random.normal(seed, mean.shape) * std + mean
6264
logp = gaussian_logpdf(sample, mean, std).sum(-1)
6365
# logp = jax.scipy.stats.norm.logpdf(sample, mean, std).sum(-1)
6466
return (-logp * outputs).mean() * 1e-2
6567
grad_fn = jax.value_and_grad(fn)
6668

6769
# Some setups
70+
expected_seed = jax.random.PRNGKey(initial_seed)
6871
expected_weights_mu = jnp.asarray([[0.13]])
6972
expected_weights_logstd = jnp.asarray([[0.04]])
7073
expected_weights = jnp.concatenate([expected_weights_mu, expected_weights_logstd], axis=-1)
@@ -73,102 +76,59 @@ def fn(params: dict, inputs: jax.Array, outputs: jax.Array):
7376
ctx.reset()
7477

7578
# Loop through 3 steps
76-
step = 1
77-
# ---------------- Step {step} --------------------
78-
print(f"------------ [Step {step}] ------------")
79-
inputs = -1**step * jnp.ones((1, 1)) / 10 # * 0.5 * step / 10.0
80-
outputs = -1**step * jnp.ones((1, 1)) / 10 # * 3 * step / 10.0# reward
81-
# --------- ngclearn ---------
82-
clamp_weights(initial_ngclearn_weights)
83-
clamp_rewards(outputs)
84-
clamp_inputs(inputs)
85-
np.random.seed(42)
86-
ctx.adapt(t=1., dt=dt)
87-
print(f"[ngclearn] objective: {a.objective.value}")
88-
print(f"[ngclearn] weights: {a.weights.value}")
89-
print(f"[ngclearn] dWeights: {a.dWeights.value}")
90-
print(f"[ngclearn] step_count: {a.step_count.value}")
91-
print(f"[ngclearn] accumulated_gradients: {a.accumulated_gradients.value}")
92-
# -------- Expectation ---------
93-
print("--------------")
94-
np.random.seed(42)
95-
expected_objective, expected_grads = grad_fn(
96-
(expected_weights_mu, expected_weights_logstd),
97-
inputs,
98-
outputs,
99-
)
100-
# NOTE: Viet: negate the gradient because gradient in ngc-learn
101-
# is gradient ascent, while gradient in JAX is gradient descent
102-
expected_grads = -jnp.concatenate([expected_grads[0], expected_grads[1]], axis=-1)
103-
expected_gradient_list.append(expected_grads)
104-
print(f"[Expectation] expected_weights: {expected_weights}")
105-
print(f"[Expectation] dWeights: {expected_grads}")
106-
print(f"[Expectation] objective: {expected_objective}")
107-
np.testing.assert_allclose(
108-
a.dWeights.value[0],
109-
expected_grads,
110-
atol=1e-8
111-
)
112-
np.testing.assert_allclose(
113-
a.objective.value,
114-
expected_objective,
115-
atol=1e-8
116-
)
117-
print()
118-
119-
120-
step = 2
121-
# ---------------- Step {step} --------------------
122-
print(f"------------ [Step {step}] ------------")
123-
inputs = -1**step * jnp.ones((1, 1)) / 10 # * 0.5 * step / 10.0
124-
outputs = -1**step * jnp.ones((1, 1)) / 10 # * 3 * step / 10.0# reward
125-
# --------- ngclearn ---------
126-
clamp_weights(initial_ngclearn_weights)
127-
clamp_rewards(outputs)
128-
clamp_inputs(inputs)
129-
np.random.seed(43)
130-
ctx.adapt(t=1., dt=dt)
131-
print(f"[ngclearn] objective: {a.objective.value}")
132-
print(f"[ngclearn] weights: {a.weights.value}")
133-
print(f"[ngclearn] dWeights: {a.dWeights.value}")
134-
print(f"[ngclearn] step_count: {a.step_count.value}")
135-
print(f"[ngclearn] accumulated_gradients: {a.accumulated_gradients.value}")
136-
# -------- Expectation ---------
137-
print("--------------")
138-
np.random.seed(43)
139-
expected_objective, expected_grads = grad_fn(
140-
(expected_weights_mu, expected_weights_logstd),
141-
inputs,
142-
outputs,
143-
)
144-
# NOTE: Viet: negate the gradient because gradient in ngc-learn
145-
# is gradient ascent, while gradient in JAX is gradient descent
146-
expected_grads = -jnp.concatenate([expected_grads[0], expected_grads[1]], axis=-1)
147-
expected_gradient_list.append(expected_grads)
148-
print(f"[Expectation] expected_weights: {expected_weights}")
149-
print(f"[Expectation] dWeights: {expected_grads}")
150-
print(f"[Expectation] objective: {expected_objective}")
151-
np.testing.assert_allclose(
152-
a.dWeights.value[0],
153-
expected_grads,
154-
atol=1e-8
155-
)
156-
np.testing.assert_allclose(
157-
a.objective.value,
158-
expected_objective,
159-
atol=1e-8
160-
)
161-
print()
79+
for step in range(10):
80+
expected_seed, expected_subseed = jax.random.split(expected_seed)
81+
82+
# ---------------- Step {step} --------------------
83+
print(f"------------ [Step {step}] ------------")
84+
inputs = -1**step * jnp.ones((1, 1)) / 10 # * 0.5 * step / 10.0
85+
outputs = -1**step * jnp.ones((1,)) / 10 # * 3 * step / 10.0# reward
86+
# --------- ngclearn ---------
87+
clamp_weights(initial_ngclearn_weights)
88+
clamp_rewards(outputs)
89+
clamp_inputs(inputs)
90+
ctx.adapt(t=1., dt=dt)
91+
print(f"[ngclearn] objective: {a.objective.value}")
92+
print(f"[ngclearn] weights: {a.weights.value}")
93+
print(f"[ngclearn] dWeights: {a.dWeights.value}")
94+
print(f"[ngclearn] step_count: {a.step_count.value}")
95+
print(f"[ngclearn] accumulated_gradients: {a.accumulated_gradients.value}")
96+
# -------- Expectation ---------
97+
print("--------------")
98+
expected_objective, expected_grads = grad_fn(
99+
(expected_weights_mu, expected_weights_logstd),
100+
inputs,
101+
outputs,
102+
expected_subseed
103+
)
104+
# NOTE: Viet: negate the gradient because gradient in ngc-learn
105+
# is gradient ascent, while gradient in JAX is gradient descent
106+
expected_grads = -jnp.concatenate([expected_grads[0], expected_grads[1]], axis=-1)
107+
expected_gradient_list.append(expected_grads)
108+
print(f"[Expectation] expected_weights: {expected_weights}")
109+
print(f"[Expectation] dWeights: {expected_grads}")
110+
print(f"[Expectation] objective: {expected_objective}")
111+
np.testing.assert_allclose(
112+
a.dWeights.value[0],
113+
expected_grads,
114+
atol=1e-8
115+
)
116+
np.testing.assert_allclose(
117+
a.objective.value,
118+
expected_objective,
119+
atol=1e-8
120+
)
121+
print()
162122

163123
# Finally, check if the accumulated gradients are correct
164-
decay_list = jnp.asarray([decay**1, decay**0])
124+
decay_list = jnp.asarray([decay**i for i in range(len(expected_gradient_list))])[::-1]
165125
expected_accumulated_gradients = jnp.mean(jnp.stack(expected_gradient_list, 0) * decay_list[:, None, None], axis=0)
166126
np.testing.assert_allclose(
167127
a.accumulated_gradients.value[0],
168128
expected_accumulated_gradients,
169-
atol=1e-8
129+
atol=1e-9
170130
)
171131

172132

173-
#test_REINFORCESynapse1()
133+
test_REINFORCESynapse1()
174134

0 commit comments

Comments
 (0)