Commit c3f39f1

update reinforce synapse and test cases
1 parent 97619e5 commit c3f39f1

File tree: 2 files changed (+53, -31 lines changed)

ngclearn/components/synapses/modulated/REINFORCESynapse.py

Lines changed: 24 additions & 13 deletions
@@ -10,6 +10,11 @@
 from ngclearn.utils import tensorstats
 from ngclearn.utils.model_utils import create_function
 
+def gaussian_logpdf(event, mean, stddev):
+    scale_sqrd = stddev ** 2
+    log_normalizer = jnp.log(2 * jnp.pi * scale_sqrd)
+    quadratic = (jax.lax.stop_gradient(event - 2 * mean) + mean)**2 / scale_sqrd
+    return - 0.5 * (log_normalizer + quadratic)
 
 class REINFORCESynapse(DenseSynapse):
 
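Aside (not part of the commit): the new helper evaluates to the standard Gaussian log-density, but the stop_gradient(event - 2 * mean) + mean construction makes its gradient with respect to the mean the negative of the usual score (event - mean) / stddev**2, which is the sign convention the updated _compute_update below relies on. A minimal standalone sketch, assuming only jax is installed; the helper is re-pasted and the numbers are hypothetical:

import jax
import jax.numpy as jnp

def gaussian_logpdf(event, mean, stddev):
    # copy of the helper added above, so this sketch runs on its own
    scale_sqrd = stddev ** 2
    log_normalizer = jnp.log(2 * jnp.pi * scale_sqrd)
    quadratic = (jax.lax.stop_gradient(event - 2 * mean) + mean) ** 2 / scale_sqrd
    return -0.5 * (log_normalizer + quadratic)

event, mean, std = 1.3, 0.4, 0.7
# the value agrees with the true log-density ...
print(gaussian_logpdf(event, mean, std))
print(jax.scipy.stats.norm.logpdf(event, mean, std))
# ... but the gradient w.r.t. the mean comes out negated
g_surrogate = jax.grad(gaussian_logpdf, argnums=1)(event, mean, std)
g_true = jax.grad(jax.scipy.stats.norm.logpdf, argnums=1)(event, mean, std)
print(g_surrogate, g_true)  # g_surrogate == -g_true == -(event - mean) / std**2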
@@ -39,6 +44,8 @@ def __init__(
         # self.seed = Component(seed)
         self.accumulated_gradients = Compartment(jnp.zeros((input_dim, output_dim * 2)))
         self.decay = decay
+        self.step_count = Compartment(jnp.zeros(()))
+        self.learning_mask = Compartment(jnp.zeros(()))
 
     @staticmethod
     def _compute_update(dt, inputs, rewards, act_fx, weights):
@@ -53,41 +60,44 @@ def _compute_update(dt, inputs, rewards, act_fx, weights):
         sample = epsilon * std + mean
         outputs = sample # the actual action that we take
         # Compute log probability density of the Gaussian
-        log_prob = -0.5 * jnp.log(2 * jnp.pi) - logstd - 0.5 * ((sample - mean) / std) ** 2
+        log_prob = gaussian_logpdf(sample, mean, std)
         log_prob = log_prob.sum(-1)
         # Compute objective (negative REINFORCE objective)
         objective = (-log_prob * rewards).mean() * 1e-2
         # Backward pass
         # Compute gradients manually based on the derivation
-        # dL/dmu = -(r-r_hat) * dlog_prob/dmu = -(r-r_hat) * (sample-mu)/sigma^2
-        dlog_prob_dmean = (sample - mean) / (std ** 2)
+        # dL/dmu = -(r-r_hat) * dlog_prob/dmu = -(r-r_hat) * -(sample-mu)/sigma^2
+        dlog_prob_dmean = -(sample - mean) / (std ** 2)
         # dL/dlog(sigma) = -(r-r_hat) * dlog_prob/dlog(sigma) = -(r-r_hat) * (((sample-mu)/sigma)^2 - 1)
         dlog_prob_dlogstd = ((sample - mean) / std) ** 2 - 1.0
         # Compute gradients with respect to weights
         # Using chain rule: dL/dW_mu = dL/dmu * dmu/dW_mu = dL/dmu * activation^T
         # Similarly for W_logstd
-        dL_dWmu = activation.T @ (-rewards[:, None] * dlog_prob_dmean) * 1e-2
-        dL_dWlstd = activation.T @ (-rewards[:, None] * dlog_prob_dlogstd) * 1e-2
+        # Gradient ascent instead of descent
+        dL_dWmu = activation.T @ (rewards[:, None] * dlog_prob_dmean) * 1e-2
+        dL_dWlstd = activation.T @ (rewards[:, None] * dlog_prob_dlogstd) * 1e-2
         # Update weights
         dW = jnp.concatenate([dL_dWmu, dL_dWlstd], axis=-1)
         # Finally, return metrics if needed
         return dW, objective, outputs
 
-    @transition(output_compartments=["weights", "dWeights", "objective", "outputs", "accumulated_gradients"])
+    @transition(output_compartments=["weights", "dWeights", "objective", "outputs", "accumulated_gradients", "step_count"])
     @staticmethod
-    def evolve(dt, w_bound, inputs, rewards, act_fx, weights, eta, decay, accumulated_gradients):
+    def evolve(dt, w_bound, inputs, rewards, act_fx, weights, eta, learning_mask, decay, accumulated_gradients, step_count):
         dWeights, objective, outputs = REINFORCESynapse._compute_update(
             dt, inputs, rewards, act_fx, weights
         )
         ## do a gradient ascent update/shift
-        weights = weights + dWeights * eta
+        weights = (weights + dWeights * eta) * learning_mask + weights * (1.0 - learning_mask) # update the weights only where learning_mask is 1.0
         ## enforce non-negativity
-        eps = 0.01 # 0.001
+        eps = 0.0 # 0.01 # 0.001
         weights = jnp.clip(weights, eps, w_bound - eps) # jnp.abs(w_bound))
-        accumulated_gradients = accumulated_gradients * decay + dWeights
-        return weights, dWeights, objective, outputs, accumulated_gradients
+        step_count += 1
+        accumulated_gradients = (step_count - 1) / step_count * accumulated_gradients * decay + 1.0 / step_count * dWeights # EMA update of accumulated gradients
+        step_count = step_count * (1 - learning_mask) # reset the step count to 0 when we have learned
+        return weights, dWeights, objective, outputs, accumulated_gradients, step_count
 
-    @transition(output_compartments=["inputs", "outputs", "objective", "rewards", "dWeights", "accumulated_gradients"])
+    @transition(output_compartments=["inputs", "outputs", "objective", "rewards", "dWeights", "accumulated_gradients", "step_count"])
     @staticmethod
     def reset(batch_size, shape):
         preVals = jnp.zeros((batch_size, shape[0]))
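For reference (not part of the commit), the hand-derived per-sample terms above, dlog_prob_dmean = -(sample - mean) / std**2 and dlog_prob_dlogstd = ((sample - mean) / std)**2 - 1, can be checked against autodiff of the new surrogate under the reparameterized draw sample = eps * std + mean. A small sketch, assuming ngclearn with this commit is importable (otherwise paste the gaussian_logpdf definition in directly); the numbers are hypothetical:

import jax
import jax.numpy as jnp
from ngclearn.components.synapses.modulated.REINFORCESynapse import gaussian_logpdf

def logp(mean, logstd, eps):
    std = jnp.exp(logstd)
    sample = eps * std + mean   # reparameterized draw; eps is held fixed
    return gaussian_logpdf(sample, mean, std)

mean, logstd, eps = 0.3, 0.2, 1.5
d_mean, d_logstd = jax.grad(logp, argnums=(0, 1))(mean, logstd, eps)

std = jnp.exp(logstd)
sample = eps * std + mean
print(d_mean, -(sample - mean) / std ** 2)            # same value
print(d_logstd, ((sample - mean) / std) ** 2 - 1.0)   # same value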
@@ -98,7 +108,8 @@ def reset(batch_size, shape):
         rewards = jnp.zeros((batch_size,))
         dWeights = jnp.zeros(shape)
         accumulated_gradients = jnp.zeros((shape[0], shape[1] * 2))
-        return inputs, outputs, objective, rewards, dWeights, accumulated_gradients
+        step_count = jnp.zeros(())
+        return inputs, outputs, objective, rewards, dWeights, accumulated_gradients, step_count
 
     @classmethod
     def help(cls): ## component help function
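A plain sketch (not part of the commit) of how the gated update and the step-count-weighted running average in evolve() behave over a few calls; the function and variable names below are local stand-ins for the compartment values, chosen only for illustration:

import jax.numpy as jnp

def evolve_step(weights, acc_grads, step_count, dW, eta, decay, learning_mask):
    # apply the increment only where learning_mask is 1.0, otherwise keep the weights
    weights = (weights + dW * eta) * learning_mask + weights * (1.0 - learning_mask)
    # running average of the per-step updates, additionally shrunk by `decay`
    step_count = step_count + 1
    acc_grads = (step_count - 1) / step_count * acc_grads * decay + 1.0 / step_count * dW
    # once an update has actually been applied, restart the average
    step_count = step_count * (1 - learning_mask)
    return weights, acc_grads, step_count

w, acc, n = jnp.ones((1, 2)), jnp.zeros((1, 2)), jnp.zeros(())
for mask in (0.0, 0.0, 1.0):  # accumulate twice, then learn and reset the count
    w, acc, n = evolve_step(w, acc, n, dW=jnp.full((1, 2), 0.5),
                            eta=0.1, decay=0.9, learning_mask=jnp.array(mask))
print(w, acc, n)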

tests/components/synapses/modulated/test_REINFORCESynapse.py

Lines changed: 29 additions & 18 deletions
@@ -5,7 +5,7 @@
 from ngcsimlib.context import Context
 import numpy as np
 np.random.seed(42)
-from ngclearn.components.synapses.modulated.REINFORCESynapse import REINFORCESynapse
+from ngclearn.components.synapses.modulated.REINFORCESynapse import REINFORCESynapse, gaussian_logpdf
 from ngcsimlib.compilers import compile_command, wrap_command
 from numpy.testing import assert_array_equal
 
@@ -52,42 +52,53 @@ def clamp_weights(x):
 
     ## check pre-synaptic STDP only
     # truth = jnp.array([[1.25]])
+    np.random.seed(42)
     ctx.reset()
-    clamp_weights(jnp.ones((1, 2)))
-    clamp_rewards(jnp.ones((1, 1)))
-    clamp_inputs(jnp.ones((1, 1)))
+    clamp_weights(jnp.ones((1, 2)) * 2)
+    clamp_rewards(jnp.ones((1, 1)) * 3)
+    clamp_inputs(jnp.ones((1, 1)) * 0.5)
     ctx.adapt(t=1., dt=dt)
     # assert_array_equal(a.dWeights.value, truth)
-    print(a.dWeights.value)
-
+    print(f"weights: {a.weights.value}")
+    print(f"dWeights: {a.dWeights.value}")
+    print(f"step_count: {a.step_count.value}")
+    print(f"accumulated_gradients: {a.accumulated_gradients.value}")
+    print(f"objective: {a.objective.value}")
 
+    np.random.seed(42)
     # JAX Grad output
     _act = jax.nn.tanh
     def fn(params: dict, inputs: jax.Array, outputs: jax.Array, seed):
         W_mu, W_logstd = params
-        mean = _act(inputs) @ W_mu
-        logstd = _act(inputs) @ W_logstd
+        activation = _act(inputs)
+        mean = activation @ W_mu
+        logstd = activation @ W_logstd
         std = jnp.exp(logstd.clip(-10.0, 2.0))
-        sample = jax.random.normal(seed, mean.shape) * std + mean
-        # logp = gaussian_logpdf(sample, mean, std).sum(-1)
-        logp = jax.scipy.stats.norm.logpdf(sample, mean, std).sum(-1)
+        # sample = jax.random.normal(seed, mean.shape) * std + mean
+        sample = jnp.asarray(np.random.normal(0, 1, mean.shape)) * std + mean
+        logp = gaussian_logpdf(sample, mean, std).sum(-1)
+        # logp = jax.scipy.stats.norm.logpdf(sample, mean, std).sum(-1)
         return (-logp * outputs).mean() * 1e-2
     grad_fn = jax.value_and_grad(fn)
 
-    weights_mu = jnp.ones((1, 1))
-    weights_logstd = jnp.ones((1, 1))
-    inputs = jnp.ones((1, 1))
-    outputs = jnp.ones((1, 1))
+    weights_mu = jnp.ones((1, 1)) * 2
+    weights_logstd = jnp.ones((1, 1)) * 2
+    inputs = jnp.ones((1, 1)) * 0.5
+    outputs = jnp.ones((1, 1)) * 3 # reward
     objective, grads = grad_fn(
         (weights_mu, weights_logstd),
         inputs,
         outputs,
         jax.random.key(42)
     )
+    print(f"expected grads: {grads}")
+    print(f"expected objective: {objective}")
    np.testing.assert_allclose(
         a.dWeights.value[0],
-        jnp.concatenate([grads[0], grads[1]], axis=-1),
-        atol=1e-2
+        # NOTE: Viet: negate the gradient because gradient in ngc-learn
+        # is gradient ascent, while gradient in JAX is gradient descent
+        -jnp.concatenate([grads[0], grads[1]], axis=-1),
+        atol=1e-8
     ) # NOTE: gradient is not exact due to different gradient computation, we need to inspect more closely
 
-# test_REINFORCESynapse1()
+test_REINFORCESynapse1()
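One more standalone sketch (not part of the commit) of why the JAX gradients are negated before the comparison: fn() above returns a loss, so jax.grad points in the descent direction, while a.dWeights holds the increment that evolve() adds to the weights. Reproducing both sides for a tiny hypothetical case, again assuming ngclearn with this commit is importable:

import jax
import jax.numpy as jnp
import numpy as np
from ngclearn.components.synapses.modulated.REINFORCESynapse import gaussian_logpdf

eps = jnp.asarray(np.random.default_rng(0).normal(size=(1, 1)))
x, reward = jnp.full((1, 1), 0.5), jnp.full((1,), 3.0)

def loss(W_mu, W_logstd):
    act = jnp.tanh(x)
    mean, logstd = act @ W_mu, act @ W_logstd
    std = jnp.exp(logstd)
    sample = eps * std + mean
    logp = gaussian_logpdf(sample, mean, std).sum(-1)
    return (-logp * reward).mean() * 1e-2

W_mu, W_logstd = jnp.full((1, 1), 2.0), jnp.full((1, 1), 2.0)
g_mu, g_logstd = jax.grad(loss, argnums=(0, 1))(W_mu, W_logstd)

# the same increment, written out as in _compute_update
act = jnp.tanh(x)
mean, std = act @ W_mu, jnp.exp(act @ W_logstd)
sample = eps * std + mean
dWmu = act.T @ (reward[:, None] * (-(sample - mean) / std ** 2)) * 1e-2
dWlstd = act.T @ (reward[:, None] * (((sample - mean) / std) ** 2 - 1.0)) * 1e-2
print(jnp.allclose(dWmu, -g_mu), jnp.allclose(dWlstd, -g_logstd))  # True True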
