Commit 1dc4bda

update code
1 parent 68d2435

2 files changed: +136, -5 lines changed

ngclearn/components/synapses/modulated/REINFORCESynapse.py
9 additions, 4 deletions

@@ -67,7 +67,8 @@ class REINFORCESynapse(DenseSynapse):
     # Define Functions
     def __init__(
         self, name, shape, eta=1e-4, decay=0.99, weight_init=None, resist_scale=1., act_fx=None,
-        p_conn=1., w_bound=1., batch_size=1, seed=None, mu_act_fx=None, mu_out_min=-jnp.inf, mu_out_max=jnp.inf, **kwargs
+        p_conn=1., w_bound=1., batch_size=1, seed=None, mu_act_fx=None, mu_out_min=-jnp.inf, mu_out_max=jnp.inf,
+        scalar_stddev=-1.0, **kwargs
     ) -> None:
         # This is because we have weights mu and weight log sigma
         input_dim, output_dim = shape
@@ -84,6 +85,7 @@ def __init__(
         self.mu_act_fx, self.dmu_act_fx = create_function(mu_act_fx if mu_act_fx is not None else "identity")
         self.mu_out_min = mu_out_min
         self.mu_out_max = mu_out_max
+        self.scalar_stddev = scalar_stddev

         ## Compartment setup
         self.dWeights = Compartment(self.weights.value * 0)
@@ -99,7 +101,8 @@ def __init__(
         self.seed = Compartment(jax.random.PRNGKey(seed if seed is not None else 42))

     @staticmethod
-    def _compute_update(dt, inputs, rewards, act_fx, weights, seed, mu_act_fx, dmu_act_fx, mu_out_min, mu_out_max):
+    def _compute_update(dt, inputs, rewards, act_fx, weights, seed, mu_act_fx, dmu_act_fx, mu_out_min, mu_out_max, scalar_stddev):
+        learning_stddev_mask = jnp.asarray(scalar_stddev <= 0.0, dtype=jnp.float32)
         # (input_dim, output_dim * 2) => (input_dim, output_dim), (input_dim, output_dim)
         W_mu, W_logstd = jnp.split(weights, 2, axis=-1)
         # Forward pass
@@ -109,6 +112,7 @@ def _compute_update(dt, inputs, rewards, act_fx, weights, seed, mu_act_fx, dmu_a
         logstd = activation @ W_logstd
         clip_logstd = jnp.clip(logstd, -10.0, 2.0)
         std = jnp.exp(clip_logstd)
+        std = learning_stddev_mask * std + (1.0 - learning_stddev_mask) * scalar_stddev  # masking trick
         # Sample using reparameterization trick
         epsilon = jax.random.normal(seed, fx_mean.shape)
         sample = epsilon * std + fx_mean
@@ -139,6 +143,7 @@ def _compute_update(dt, inputs, rewards, act_fx, weights, seed, mu_act_fx, dmu_a
             dL_dstd * std
         )
         dL_dWlogstd = activation.T @ dL_dlogstd  # (I, B) @ (B, A) = (I, A)
+        dL_dWlogstd = dL_dWlogstd * learning_stddev_mask  # there is no learning for the scalar stddev

         # Update weights, negate the gradient because gradient ascent in ngc-learn
         dW = jnp.concatenate([-dL_dWmu, -dL_dWlogstd], axis=-1)
@@ -147,10 +152,10 @@ def _compute_update(dt, inputs, rewards, act_fx, weights, seed, mu_act_fx, dmu_a

     @transition(output_compartments=["weights", "dWeights", "objective", "outputs", "accumulated_gradients", "step_count", "seed"])
     @staticmethod
-    def evolve(dt, w_bound, inputs, rewards, act_fx, weights, eta, learning_mask, decay, accumulated_gradients, step_count, seed, mu_act_fx, dmu_act_fx, mu_out_min, mu_out_max):
+    def evolve(dt, w_bound, inputs, rewards, act_fx, weights, eta, learning_mask, decay, accumulated_gradients, step_count, seed, mu_act_fx, dmu_act_fx, mu_out_min, mu_out_max, scalar_stddev):
         main_seed, sub_seed = jax.random.split(seed)
         dWeights, objective, outputs = REINFORCESynapse._compute_update(
-            dt, inputs, rewards, act_fx, weights, sub_seed, mu_act_fx, dmu_act_fx, mu_out_min, mu_out_max
+            dt, inputs, rewards, act_fx, weights, sub_seed, mu_act_fx, dmu_act_fx, mu_out_min, mu_out_max, scalar_stddev
         )
         ## do a gradient ascent update/shift
         weights = (weights + dWeights * eta) * learning_mask + weights * (1.0 - learning_mask)  # update the weights only where learning_mask is 1.0
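
Note on the change above: when scalar_stddev is positive, _compute_update substitutes the fixed scalar for the learned, state-dependent standard deviation and zeroes the corresponding log-std gradient. A minimal standalone sketch of that masking trick follows (the helper name select_stddev is illustrative only and not part of the ngc-learn API):

import jax.numpy as jnp

def select_stddev(learned_std, scalar_stddev):
    # mask == 1.0 when the stddev is learned (scalar_stddev <= 0.0), else 0.0
    mask = jnp.asarray(scalar_stddev <= 0.0, dtype=jnp.float32)
    # keep the learned stddev, or substitute the fixed scalar, without branching
    std = mask * learned_std + (1.0 - mask) * scalar_stddev
    return std, mask

learned_std = jnp.exp(jnp.clip(jnp.asarray([[0.04]]), -10.0, 2.0))  # stddev derived from W_logstd
std_a, mask_a = select_stddev(learned_std, scalar_stddev=-1.0)       # learned stddev kept, mask = 1.0
std_b, mask_b = select_stddev(learned_std, scalar_stddev=2.0)        # fixed stddev 2.0 used, mask = 0.0
dL_dWlogstd = jnp.ones((1, 1)) * mask_b                              # log-std gradient zeroed when fixed

Writing the switch as a mask rather than a Python if keeps the computation uniform under jit, at the cost of always evaluating the learned stddev branch.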

tests/components/synapses/modulated/test_REINFORCESynapse.py
127 additions, 1 deletion

@@ -18,6 +18,7 @@
 import jax.numpy as jnp

 def test_REINFORCESynapse1():
+    # Testing reinforce synapse with learning stddev
     name = "reinforce_ctx"
     ## create seeding keys
     np.random.seed(42)
@@ -34,7 +35,8 @@ def test_REINFORCESynapse1():
         a = REINFORCESynapse(
             name="a", shape=(1,1), decay=decay,
             act_fx="tanh", key=subkeys[0], seed=initial_seed,
-            mu_act_fx="tanh", mu_out_min=mu_out_min, mu_out_max=mu_out_max
+            mu_act_fx="tanh", mu_out_min=mu_out_min, mu_out_max=mu_out_max,
+            scalar_stddev=-1.0
         )

         evolve_process = (Process("evolve_proc") >> a.evolve)
@@ -137,3 +139,127 @@ def fn(params: dict, inputs: jax.Array, outputs: jax.Array, seed: jax.Array):

 # test_REINFORCESynapse1()

+
+
+def test_REINFORCESynapse2():
+    # Testing reinforce synapse with scalar stddev = 2.0
+    name = "reinforce_ctx"
+    ## create seeding keys
+    np.random.seed(42)
+    dkey = random.PRNGKey(1234)
+    dkey, *subkeys = random.split(dkey, 6)
+    dt = 1.  # ms
+    decay = 0.99
+    initial_seed = 42
+    mu_out_min = -jnp.inf
+    mu_out_max = jnp.inf
+    scalar_stddev = 2.0
+
+    # ---- build a simple Poisson cell system ----
+    with Context(name) as ctx:
+        a = REINFORCESynapse(
+            name="a", shape=(1,1), decay=decay,
+            act_fx="tanh", key=subkeys[0], seed=initial_seed,
+            mu_act_fx="tanh", mu_out_min=mu_out_min, mu_out_max=mu_out_max,
+            scalar_stddev=scalar_stddev
+        )
+
+        evolve_process = (Process("evolve_proc") >> a.evolve)
+        ctx.wrap_and_add_command(jit(evolve_process.pure), name="adapt")
+
+        reset_process = (Process("reset_proc") >> a.reset)
+        ctx.wrap_and_add_command(jit(reset_process.pure), name="reset")
+
+        @Context.dynamicCommand
+        def clamp_inputs(x):
+            a.inputs.set(x)
+
+        @Context.dynamicCommand
+        def clamp_rewards(x):
+            assert x.ndim == 1, "Rewards must be a 1D array"
+            a.rewards.set(x)
+
+        @Context.dynamicCommand
+        def clamp_weights(x):
+            a.weights.set(x)
+
+    # Function definition
+    _act = jax.nn.tanh
+    def fn(params: dict, inputs: jax.Array, outputs: jax.Array, seed: jax.Array):
+        W_mu, W_logstd = params
+        activation = _act(inputs)
+        mean = activation @ W_mu
+        mean = jax.nn.tanh(mean)
+        # logstd = activation @ W_logstd
+        # std = jnp.exp(logstd.clip(-10.0, 2.0))
+        std = scalar_stddev
+        sample = jax.random.normal(seed, mean.shape) * std + mean
+        sample = jnp.clip(sample, mu_out_min, mu_out_max)
+        logp = gaussian_logpdf(jax.lax.stop_gradient(sample), mean, std).sum(-1)
+        return (-logp * outputs).mean() * 1e-2
+    grad_fn = jax.value_and_grad(fn)

+    # Some setups
+    expected_seed = jax.random.PRNGKey(initial_seed)
+    expected_weights_mu = jnp.asarray([[0.13]])
+    expected_weights_logstd = jnp.asarray([[0.04]])
+    expected_weights = jnp.concatenate([expected_weights_mu, expected_weights_logstd], axis=-1)
+    initial_ngclearn_weights = jnp.concatenate([expected_weights_mu, expected_weights_logstd], axis=-1)[None]
+    expected_gradient_list = []
+    ctx.reset()
+
+    # Loop through 10 steps
+    for step in range(10):
+        expected_seed, expected_subseed = jax.random.split(expected_seed)
+
+        # ---------------- Step {step} --------------------
+        print(f"------------ [Step {step}] ------------")
+        inputs = -1**step * jnp.ones((1, 1)) / 10  # * 0.5 * step / 10.0
+        outputs = -1**step * jnp.ones((1,)) / 10  # * 3 * step / 10.0  # reward
+        # --------- ngclearn ---------
+        clamp_weights(initial_ngclearn_weights)
+        clamp_rewards(outputs)
+        clamp_inputs(inputs)
+        ctx.adapt(t=1., dt=dt)
+        print(f"[ngclearn] objective: {a.objective.value}")
+        print(f"[ngclearn] weights: {a.weights.value}")
+        print(f"[ngclearn] dWeights: {a.dWeights.value}")
+        print(f"[ngclearn] step_count: {a.step_count.value}")
+        print(f"[ngclearn] accumulated_gradients: {a.accumulated_gradients.value}")
+        # -------- Expectation ---------
+        print("--------------")
+        expected_objective, expected_grads = grad_fn(
+            (expected_weights_mu, expected_weights_logstd),
+            inputs,
+            outputs,
+            expected_subseed
+        )
+        # NOTE: Viet: negate the gradient because gradient in ngc-learn
+        # is gradient ascent, while gradient in JAX is gradient descent
+        expected_grads = -jnp.concatenate([expected_grads[0], expected_grads[1]], axis=-1)
+        expected_gradient_list.append(expected_grads)
+        print(f"[Expectation] expected_weights: {expected_weights}")
+        print(f"[Expectation] dWeights: {expected_grads}")
+        print(f"[Expectation] objective: {expected_objective}")
+        np.testing.assert_allclose(
+            a.dWeights.value[0],
+            expected_grads,
+            atol=1e-8
+        )
+        np.testing.assert_allclose(
+            a.objective.value,
+            expected_objective,
+            atol=1e-8
+        )
+        print()
+
+    # Finally, check if the accumulated gradients are correct
+    decay_list = jnp.asarray([decay**i for i in range(len(expected_gradient_list))])[::-1]
+    expected_accumulated_gradients = jnp.mean(jnp.stack(expected_gradient_list, 0) * decay_list[:, None, None], axis=0)
+    np.testing.assert_allclose(
+        a.accumulated_gradients.value[0],
+        expected_accumulated_gradients,
+        atol=1e-9
+    )
+
+# test_REINFORCESynapse2()
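
The reference fn in both tests relies on stopping the gradient through the sampled action: with the sample detached, jax.value_and_grad of -logp(sample) * reward yields the REINFORCE (score-function) gradient with respect to the mean rather than a reparameterized/pathwise one. A minimal sketch of that property follows (names are illustrative; gaussian_logpdf in the tests is assumed to behave like jax.scipy.stats.norm.logpdf):

import jax
import jax.numpy as jnp
from jax.scipy.stats import norm

def surrogate(mean, sample, reward, std):
    # detach the sample so only the log-density depends on the parameters
    logp = norm.logpdf(jax.lax.stop_gradient(sample), mean, std).sum(-1)
    return (-logp * reward).mean()

mean = jnp.asarray([0.3])
std = 2.0
reward = jnp.asarray(0.1)
sample = jax.random.normal(jax.random.PRNGKey(0), mean.shape) * std + mean

grad_mean = jax.grad(surrogate)(mean, sample, reward, std)
# Gaussian score w.r.t. the mean is (sample - mean) / std**2, so the surrogate's gradient is:
expected = -reward * (sample - mean) / std ** 2
assert jnp.allclose(grad_mean, expected)

The final assertion in test_REINFORCESynapse2 then checks that accumulated_gradients matches a decay-weighted mean of the per-step gradients, with the most recent step weighted by decay**0.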