
Commit 40812ff

update test code for more than 1 step
1 parent 92def68 commit 40812ff

2 files changed: +107 -35 lines changed

ngclearn/components/synapses/modulated/REINFORCESynapse.py

Lines changed: 2 additions & 0 deletions
@@ -67,6 +67,8 @@ def _compute_update(dt, inputs, rewards, act_fx, weights):
     # Backward pass
     # Compute gradients manually based on the derivation
     # dL/dmu = -(r-r_hat) * dlog_prob/dmu = -(r-r_hat) * -(sample-mu)/sigma^2
+    # -(sample - mean) instead of (sample - mean) because we take a straight-through gradient in the log_prob function,
+    # so computation involving the mean inside that function does not contribute to the gradient
     dlog_prob_dmean = -(sample - mean) / (std ** 2)
     # dL/dlog(sigma) = -(r-r_hat) * dlog_prob/dlog(sigma) = -(r-r_hat) * (((sample-mu)/sigma)^2 - 1)
     dlog_prob_dlogstd = ((sample - mean) / std) ** 2 - 1.0
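As a sanity reference (not part of the commit), the two expressions above are, up to the sign flip that the new comments attribute to the straight-through construction, the standard derivatives of a Gaussian log-density with respect to its mean and log standard deviation. A minimal JAX check of those identities, using a hypothetical local gaussian_log_prob helper and scalar values rather than the component's own code:

import jax
import jax.numpy as jnp

def gaussian_log_prob(mean, logstd, sample):
    # log N(sample; mean, exp(logstd)^2), written out explicitly
    std = jnp.exp(logstd)
    return -0.5 * ((sample - mean) / std) ** 2 - logstd - 0.5 * jnp.log(2.0 * jnp.pi)

mean, logstd, sample = 0.1, -0.2, 0.4
std = jnp.exp(logstd)
d_mean = jax.grad(gaussian_log_prob, argnums=0)(mean, logstd, sample)
d_logstd = jax.grad(gaussian_log_prob, argnums=1)(mean, logstd, sample)
# Standard identities: dlogp/dmean = (sample - mean) / std**2
#                      dlogp/dlogstd = ((sample - mean) / std)**2 - 1
assert jnp.allclose(d_mean, (sample - mean) / std ** 2)
assert jnp.allclose(d_logstd, ((sample - mean) / std) ** 2 - 1.0)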

tests/components/synapses/modulated/test_REINFORCESynapse.py

Lines changed: 105 additions & 35 deletions
@@ -24,10 +24,11 @@ def test_REINFORCESynapse1():
     dkey = random.PRNGKey(1234)
     dkey, *subkeys = random.split(dkey, 6)
     dt = 1. # ms
+    decay = 0.99
     # ---- build a simple Poisson cell system ----
     with Context(name) as ctx:
         a = REINFORCESynapse(
-            name="a", shape=(1,1), act_fx="tanh", key=subkeys[0]
+            name="a", shape=(1,1), decay=decay, act_fx="tanh", key=subkeys[0]
         )

         evolve_process = (Process() >> a.evolve)
@@ -48,27 +49,9 @@ def clamp_rewards(x):
     def clamp_weights(x):
         a.weights.set(x)

-    # a.weights.set(jnp.ones((1, 1)) * 0.1)
-
-    ## check pre-synaptic STDP only
-    # truth = jnp.array([[1.25]])
-    np.random.seed(42)
-    ctx.reset()
-    clamp_weights(jnp.ones((1, 2)) * 2)
-    clamp_rewards(jnp.ones((1, 1)) * 3)
-    clamp_inputs(jnp.ones((1, 1)) * 0.5)
-    ctx.adapt(t=1., dt=dt)
-    # assert_array_equal(a.dWeights.value, truth)
-    print(f"weights: {a.weights.value}")
-    print(f"dWeights: {a.dWeights.value}")
-    print(f"step_count: {a.step_count.value}")
-    print(f"accumulated_gradients: {a.accumulated_gradients.value}")
-    print(f"objective: {a.objective.value}")
-
-    np.random.seed(42)
-    # JAX Grad output
+    # Function definition
     _act = jax.nn.tanh
-    def fn(params: dict, inputs: jax.Array, outputs: jax.Array, seed):
+    def fn(params: dict, inputs: jax.Array, outputs: jax.Array):
         W_mu, W_logstd = params
         activation = _act(inputs)
         mean = activation @ W_mu
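The reference fn above follows the usual REINFORCE surrogate-loss pattern: differentiating a reward-weighted negative log-probability gives minus the reward times the score function. A self-contained scalar illustration of that pattern (the surrogate helper and its values are hypothetical; fn's full body lies outside the changed lines of this diff):

import jax
import jax.numpy as jnp

def surrogate(mean, sample, reward, logstd=0.0):
    # reward-weighted negative Gaussian log-likelihood of a fixed sample
    std = jnp.exp(logstd)
    logp = -0.5 * ((sample - mean) / std) ** 2 - logstd - 0.5 * jnp.log(2.0 * jnp.pi)
    return -(logp * reward)

mean, sample, reward = 0.2, 0.7, -0.1
g = jax.grad(surrogate)(mean, sample, reward)
# the surrogate's gradient is -reward * dlogp/dmean = -reward * (sample - mean) / std**2
assert jnp.allclose(g, -reward * (sample - mean) / jnp.exp(0.0) ** 2)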
@@ -81,24 +64,111 @@ def fn(params: dict, inputs: jax.Array, outputs: jax.Array, seed):
         return (-logp * outputs).mean() * 1e-2
     grad_fn = jax.value_and_grad(fn)

-    weights_mu = jnp.ones((1, 1)) * 2
-    weights_logstd = jnp.ones((1, 1)) * 2
-    inputs = jnp.ones((1, 1)) * 0.5
-    outputs = jnp.ones((1, 1)) * 3  # reward
-    objective, grads = grad_fn(
-        (weights_mu, weights_logstd),
+    # Some setups
+    expected_weights_mu = jnp.asarray([[0.13]])
+    expected_weights_logstd = jnp.asarray([[0.04]])
+    expected_weights = jnp.concatenate([expected_weights_mu, expected_weights_logstd], axis=-1)
+    initial_ngclearn_weights = jnp.concatenate([expected_weights_mu, expected_weights_logstd], axis=-1)[None]
+    expected_gradient_list = []
+    ctx.reset()
+
+    # Loop through 2 steps
+    step = 1
+    # ---------------- Step {step} --------------------
+    print(f"------------ [Step {step}] ------------")
+    inputs = -1**step * jnp.ones((1, 1)) / 10  # * 0.5 * step / 10.0
+    outputs = -1**step * jnp.ones((1, 1)) / 10  # * 3 * step / 10.0  # reward
+    # --------- ngclearn ---------
+    clamp_weights(initial_ngclearn_weights)
+    clamp_rewards(outputs)
+    clamp_inputs(inputs)
+    np.random.seed(42)
+    ctx.adapt(t=1., dt=dt)
+    print(f"[ngclearn] objective: {a.objective.value}")
+    print(f"[ngclearn] weights: {a.weights.value}")
+    print(f"[ngclearn] dWeights: {a.dWeights.value}")
+    print(f"[ngclearn] step_count: {a.step_count.value}")
+    print(f"[ngclearn] accumulated_gradients: {a.accumulated_gradients.value}")
+    # -------- Expectation ---------
+    print("--------------")
+    np.random.seed(42)
+    expected_objective, expected_grads = grad_fn(
+        (expected_weights_mu, expected_weights_logstd),
+        inputs,
+        outputs,
+    )
+    # NOTE: Viet: negate the gradient because the gradient in ngc-learn
+    # is for gradient ascent, while the gradient in JAX is for gradient descent
+    expected_grads = -jnp.concatenate([expected_grads[0], expected_grads[1]], axis=-1)
+    expected_gradient_list.append(expected_grads)
+    print(f"[Expectation] expected_weights: {expected_weights}")
+    print(f"[Expectation] dWeights: {expected_grads}")
+    print(f"[Expectation] objective: {expected_objective}")
+    np.testing.assert_allclose(
+        a.dWeights.value[0],
+        expected_grads,
+        atol=1e-8
+    )
+    np.testing.assert_allclose(
+        a.objective.value,
+        expected_objective,
+        atol=1e-8
+    )
+    print()
+
+
+    step = 2
+    # ---------------- Step {step} --------------------
+    print(f"------------ [Step {step}] ------------")
+    inputs = -1**step * jnp.ones((1, 1)) / 10  # * 0.5 * step / 10.0
+    outputs = -1**step * jnp.ones((1, 1)) / 10  # * 3 * step / 10.0  # reward
+    # --------- ngclearn ---------
+    clamp_weights(initial_ngclearn_weights)
+    clamp_rewards(outputs)
+    clamp_inputs(inputs)
+    np.random.seed(43)
+    ctx.adapt(t=1., dt=dt)
+    print(f"[ngclearn] objective: {a.objective.value}")
+    print(f"[ngclearn] weights: {a.weights.value}")
+    print(f"[ngclearn] dWeights: {a.dWeights.value}")
+    print(f"[ngclearn] step_count: {a.step_count.value}")
+    print(f"[ngclearn] accumulated_gradients: {a.accumulated_gradients.value}")
+    # -------- Expectation ---------
+    print("--------------")
+    np.random.seed(43)
+    expected_objective, expected_grads = grad_fn(
+        (expected_weights_mu, expected_weights_logstd),
         inputs,
         outputs,
-        jax.random.key(42)
     )
-    print(f"expected grads: {grads}")
-    print(f"expected objective: {objective}")
+    # NOTE: Viet: negate the gradient because the gradient in ngc-learn
+    # is for gradient ascent, while the gradient in JAX is for gradient descent
+    expected_grads = -jnp.concatenate([expected_grads[0], expected_grads[1]], axis=-1)
+    expected_gradient_list.append(expected_grads)
+    print(f"[Expectation] expected_weights: {expected_weights}")
+    print(f"[Expectation] dWeights: {expected_grads}")
+    print(f"[Expectation] objective: {expected_objective}")
     np.testing.assert_allclose(
         a.dWeights.value[0],
-        # NOTE: Viet: negate the gradient because gradient in ngc-learn
-        # is gradient ascent, while gradient in JAX is gradient descent
-        -jnp.concatenate([grads[0], grads[1]], axis=-1),
+        expected_grads,
         atol=1e-8
-    )  # NOTE: gradient is not exact due to different gradient computation, we need to inspect more closely
+    )
+    np.testing.assert_allclose(
+        a.objective.value,
+        expected_objective,
+        atol=1e-8
+    )
+    print()
+
+    # Finally, check if the accumulated gradients are correct
+    decay_list = jnp.asarray([decay**1, decay**0])
+    expected_accumulated_gradients = jnp.mean(jnp.stack(expected_gradient_list, 0) * decay_list[:, None, None], axis=0)
+    np.testing.assert_allclose(
+        a.accumulated_gradients.value[0],
+        expected_accumulated_gradients,
+        atol=1e-8
+    )
+
+
+test_REINFORCESynapse1()

-# test_REINFORCESynapse1()
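The final assertion encodes the expected accumulator as a decay-weighted mean of the per-step gradients, with the newest gradient weighted by decay**0 and the oldest by decay**1. A small sketch of the same closed form for an arbitrary number of steps (the expected_accumulated helper is hypothetical, and it assumes each recorded gradient is a 2-D array like the (1, 2) blocks above):

import jax.numpy as jnp

def expected_accumulated(gradient_list, decay):
    # weight the oldest gradient by decay**(T-1) down to decay**0 for the newest,
    # then average over steps, mirroring the two-step expression in the test
    T = len(gradient_list)
    decays = jnp.asarray([decay ** (T - 1 - t) for t in range(T)])
    return jnp.mean(jnp.stack(gradient_list, axis=0) * decays[:, None, None], axis=0)

With T = 2 this reduces to the test's jnp.mean(jnp.stack(expected_gradient_list, 0) * jnp.asarray([decay**1, decay**0])[:, None, None], axis=0).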
