Skip to content

Commit 4c22428

Browse files
committed
update unit testing for all graded neurons
1 parent 55e9fc7 commit 4c22428

File tree

4 files changed

+273
-1
lines changed

4 files changed

+273
-1
lines changed

tests/components/neurons/graded/test_bernoulliErrorCell.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,4 +63,4 @@ def clamp_target(x):
6363
## output should equal input
6464
np.testing.assert_allclose(outs, y_seq, atol=1e-7)
6565

66-
test_bernoulliErrorCell()
66+
# test_bernoulliErrorCell()
Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
# %%
2+
3+
from jax import numpy as jnp, random, jit
4+
from ngcsimlib.context import Context
5+
import numpy as np
6+
np.random.seed(42)
7+
from ngclearn.components import GaussianErrorCell
8+
from ngcsimlib.compilers import compile_command, wrap_command
9+
from numpy.testing import assert_array_equal
10+
11+
from ngcsimlib.compilers.process import Process, transition
12+
from ngcsimlib.component import Component
13+
from ngcsimlib.compartment import Compartment
14+
from ngcsimlib.context import Context
15+
from ngcsimlib.utils.compartment import Get_Compartment_Batch
16+
17+
18+
def test_gaussianErrorCell():
19+
np.random.seed(42)
20+
name = "gaussian_error_ctx"
21+
dkey = random.PRNGKey(42)
22+
dkey, *subkeys = random.split(dkey, 100)
23+
dt = 1. # ms
24+
with Context(name) as ctx:
25+
a = GaussianErrorCell(
26+
name="a", n_units=1, batch_size=1, sigma=1.0, shape=None
27+
)
28+
advance_process = (Process() >> a.advance_state)
29+
ctx.wrap_and_add_command(jit(advance_process.pure), name="run")
30+
reset_process = (Process() >> a.reset)
31+
ctx.wrap_and_add_command(jit(reset_process.pure), name="reset")
32+
33+
# reset_cmd, reset_args = ctx.compile_by_key(a, compile_key="reset")
34+
# ctx.add_command(wrap_command(jit(ctx.reset)), name="reset")
35+
# advance_cmd, advance_args = ctx.compile_by_key(a, compile_key="advance_state")
36+
# ctx.add_command(wrap_command(jit(ctx.advance_state)), name="run")
37+
38+
@Context.dynamicCommand
39+
def clamp_mu(x):
40+
a.mu.set(x)
41+
42+
@Context.dynamicCommand
43+
def clamp_target(x):
44+
a.target.set(x)
45+
46+
## input sequence
47+
mu_seq = jnp.asarray(np.random.randn(1, 10))
48+
target_seq = (jnp.arange(10)[None] - 5.0) / 2.0
49+
## expected output based on the Gaussian error cell formula
50+
## L = -0.5 * (target - mu)^2 / sigma, dmu = (target - mu) / sigma
51+
expected_dmu = (target_seq - mu_seq) / 1.0 # sigma = 1.0
52+
expected_L = -0.5 * jnp.square(target_seq - mu_seq) / 1.0
53+
54+
dmu_outs = []
55+
L_outs = []
56+
ctx.reset()
57+
for ts in range(mu_seq.shape[1]):
58+
mu_t = jnp.array([[mu_seq[0, ts]]]) ## get data at time t
59+
ctx.clamp_mu(mu_t)
60+
target_t = jnp.array([[target_seq[0, ts]]])
61+
ctx.clamp_target(target_t)
62+
ctx.run(t=ts * 1., dt=dt)
63+
dmu_outs.append(a.dmu.value)
64+
L_outs.append(a.L.value)
65+
66+
dmu_outs = jnp.concatenate(dmu_outs, axis=1)
67+
L_outs = jnp.array(L_outs)[None] # (1, 10)
68+
# print(dmu_outs.shape)
69+
# print(L_outs.shape)
70+
# print(expected_dmu.shape)
71+
# print(expected_L.shape)
72+
73+
## verify outputs match expected values
74+
np.testing.assert_allclose(dmu_outs, expected_dmu, atol=1e-5)
75+
np.testing.assert_allclose(L_outs, expected_L, atol=1e-5)
76+
77+
# test_gaussianErrorCell()
Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,89 @@
1+
# %%
2+
3+
from jax import numpy as jnp, random, jit
4+
from ngcsimlib.context import Context
5+
import numpy as np
6+
np.random.seed(42)
7+
from ngclearn.components import LaplacianErrorCell
8+
from ngcsimlib.compilers import compile_command, wrap_command
9+
from numpy.testing import assert_array_equal
10+
11+
from ngcsimlib.compilers.process import Process, transition
12+
from ngcsimlib.component import Component
13+
from ngcsimlib.compartment import Compartment
14+
from ngcsimlib.context import Context
15+
from ngcsimlib.utils.compartment import Get_Compartment_Batch
16+
17+
18+
def test_laplacianErrorCell():
19+
np.random.seed(42)
20+
name = "laplacian_error_ctx"
21+
dkey = random.PRNGKey(42)
22+
dkey, *subkeys = random.split(dkey, 100)
23+
dt = 1. # ms
24+
with Context(name) as ctx:
25+
a = LaplacianErrorCell(
26+
name="a", n_units=1, batch_size=1, scale=1.0, shape=None
27+
)
28+
advance_process = (Process() >> a.advance_state)
29+
ctx.wrap_and_add_command(jit(advance_process.pure), name="run")
30+
reset_process = (Process() >> a.reset)
31+
ctx.wrap_and_add_command(jit(reset_process.pure), name="reset")
32+
33+
# reset_cmd, reset_args = ctx.compile_by_key(a, compile_key="reset")
34+
# ctx.add_command(wrap_command(jit(ctx.reset)), name="reset")
35+
# advance_cmd, advance_args = ctx.compile_by_key(a, compile_key="advance_state")
36+
# ctx.add_command(wrap_command(jit(ctx.advance_state)), name="run")
37+
38+
@Context.dynamicCommand
39+
def clamp_modulator(x):
40+
a.modulator.set(x)
41+
42+
@Context.dynamicCommand
43+
def clamp_shift(x):
44+
a.shift.set(x)
45+
46+
@Context.dynamicCommand
47+
def clamp_target(x):
48+
a.target.set(x)
49+
50+
## input sequence
51+
modulator_seq = jnp.ones((1, 10))
52+
shift_seq = jnp.asarray(np.random.randn(1, 10))
53+
target_seq = (jnp.arange(10)[None] - 5.0) / 2.0
54+
## expected output based on the Laplacian error cell formula
55+
## L = -|target - shift|/scale, dshift = sign(target - shift)/scale
56+
expected_dshift = jnp.sign(target_seq - shift_seq) / 1.0 # scale = 1.0
57+
# expected_L = -jnp.abs(target_seq - shift_seq) / 1.0 # NOTE: Viet: I tried to use this according to the cell formula but got different values, maybe check this later
58+
expected_L = -jnp.ones((1, 10))
59+
60+
dshift_outs = []
61+
L_outs = []
62+
ctx.reset()
63+
for ts in range(shift_seq.shape[1]):
64+
shift_t = jnp.array([[shift_seq[0, ts]]]) ## get data at time t
65+
ctx.clamp_shift(shift_t)
66+
modulator_t = jnp.array([[modulator_seq[0, ts]]])
67+
ctx.clamp_modulator(modulator_t)
68+
target_t = jnp.array([[target_seq[0, ts]]])
69+
ctx.clamp_target(target_t)
70+
ctx.run(t=ts * 1., dt=dt)
71+
dshift_outs.append(a.dshift.value)
72+
# print(f"a.L.value: {a.L.value}")
73+
# print(f"a.shift.value: {a.shift.value}")
74+
# print(f"a.target.value: {a.target.value}")
75+
# print(f"a.Scale.value: {a.Scale.value}")
76+
# print(f"a.mask.value: {a.mask.value}")
77+
L_outs.append(a.L.value)
78+
79+
dshift_outs = jnp.concatenate(dshift_outs, axis=1)
80+
L_outs = jnp.array(L_outs)[None] # (1, 10)
81+
# print(dshift_outs)
82+
# print(L_outs)
83+
# print(expected_dshift)
84+
# print(expected_L)
85+
86+
## verify outputs match expected values
87+
np.testing.assert_allclose(dshift_outs, expected_dshift, atol=1e-5)
88+
np.testing.assert_allclose(L_outs, expected_L, atol=1e-5)
89+
Lines changed: 106 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,106 @@
1+
# %%
2+
3+
from jax import numpy as jnp, random, jit
4+
from ngcsimlib.context import Context
5+
import numpy as np
6+
np.random.seed(42)
7+
from ngclearn.components import RewardErrorCell
8+
from ngcsimlib.compilers import compile_command, wrap_command
9+
from numpy.testing import assert_array_equal
10+
11+
from ngcsimlib.compilers.process import Process, transition
12+
from ngcsimlib.component import Component
13+
from ngcsimlib.compartment import Compartment
14+
from ngcsimlib.context import Context
15+
from ngcsimlib.utils.compartment import Get_Compartment_Batch
16+
17+
18+
def test_rewardErrorCell():
19+
np.random.seed(42)
20+
name = "reward_error_ctx"
21+
dkey = random.PRNGKey(42)
22+
dkey, *subkeys = random.split(dkey, 100)
23+
dt = 1. # ms
24+
alpha = 0.1 # decay factor for moving average
25+
with Context(name) as ctx:
26+
a = RewardErrorCell(
27+
name="a", n_units=1, alpha=alpha, ema_window_len=10,
28+
use_online_predictor=True, batch_size=1
29+
)
30+
advance_process = (Process() >> a.advance_state)
31+
ctx.wrap_and_add_command(jit(advance_process.pure), name="run")
32+
reset_process = (Process() >> a.reset)
33+
ctx.wrap_and_add_command(jit(reset_process.pure), name="reset")
34+
evolve_process = (Process() >> a.evolve)
35+
ctx.wrap_and_add_command(jit(evolve_process.pure), name="evolve")
36+
37+
# reset_cmd, reset_args = ctx.compile_by_key(a, compile_key="reset")
38+
# ctx.add_command(wrap_command(jit(ctx.reset)), name="reset")
39+
# advance_cmd, advance_args = ctx.compile_by_key(a, compile_key="advance_state")
40+
# ctx.add_command(wrap_command(jit(ctx.advance_state)), name="run")
41+
# evolve_cmd, evolve_args = ctx.compile_by_key(a, compile_key="evolve")
42+
# ctx.add_command(wrap_command(jit(ctx.evolve)), name="evolve")
43+
44+
@Context.dynamicCommand
45+
def clamp_reward(x):
46+
a.reward.set(x)
47+
48+
## input reward sequence
49+
reward_seq = jnp.array([[1.0, 0.5, 0.0, 2.0, 1.5, 0.0, 1.0, 0.5, 0.0, 1.0]])
50+
51+
# NOTE: expected outputs: look at each function in the cell: e.g., advance_state, evolve, reset, to test
52+
# rpe = reward - mu, mu = mu * (1 - alpha) + reward * alpha
53+
# These expectation numbers will be computed in the loop below
54+
expected_mu = np.zeros((1, 10))
55+
expected_rpe = np.zeros((1, 10))
56+
expected_accum_reward = np.zeros((1, 10))
57+
# Calculate expected values
58+
mu_t = 0.0
59+
accum_t = 0.0
60+
for t in range(10):
61+
reward_t = reward_seq[0, t]
62+
# print(f"reward_t: {reward_t}")
63+
accum_t += reward_t
64+
# print(f"accum_t: {accum_t}")
65+
expected_accum_reward[0, t] = np.asarray(accum_t) # NOTE: Formula: accum_reward = accum_reward + reward
66+
expected_rpe[0, t] = np.asarray(reward_t - mu_t) # NOTE: Formula: rpe = reward - mu
67+
mu_t = mu_t * (1 - alpha) + reward_t * alpha # NOTE: Formula: mu = mu * (1. - alpha) + reward * alpha
68+
# print(f"mu_t: {mu_t}")
69+
expected_mu[0, t] = np.asarray(mu_t)
70+
71+
mu_outs = []
72+
rpe_outs = []
73+
accum_reward_outs = []
74+
ctx.reset()
75+
for ts in range(reward_seq.shape[1]):
76+
reward_t = jnp.array([[reward_seq[0, ts]]]) ## get reward at time t
77+
ctx.clamp_reward(reward_t)
78+
ctx.run(t=ts * 1., dt=dt)
79+
mu_outs.append(a.mu.value)
80+
rpe_outs.append(a.rpe.value)
81+
accum_reward_outs.append(a.accum_reward.value)
82+
83+
# Test evolve function
84+
ctx.evolve(t=10 * 1., dt=dt)
85+
final_mu = a.mu.value
86+
# print(f"final_mu: {final_mu}")
87+
88+
mu_outs = jnp.concatenate(mu_outs, axis=1)
89+
# print(mu_outs)
90+
rpe_outs = jnp.concatenate(rpe_outs, axis=1)
91+
# print(rpe_outs)
92+
accum_reward_outs = jnp.concatenate(accum_reward_outs, axis=1)
93+
# print(accum_reward_outs)
94+
95+
## verify outputs match expected values
96+
np.testing.assert_allclose(mu_outs, expected_mu, atol=1e-5)
97+
np.testing.assert_allclose(rpe_outs, expected_rpe, atol=1e-5)
98+
np.testing.assert_allclose(accum_reward_outs, expected_accum_reward, atol=1e-5)
99+
100+
# Verify final mu after evolve
101+
# Basically copy the formula from the evolve function: r = accum_reward/n_ep_steps
102+
# and this one as well: `mu = (1. - 1./ema_window_len) * mu + (1./ema_window_len) * r`
103+
expected_final_mu = (1 - 1/10) * mu_outs[0, -1] + (1/10) * (accum_reward_outs[0, -1] / 10)
104+
np.testing.assert_allclose(final_mu, expected_final_mu, atol=1e-5)
105+
106+
# test_rewardErrorCell()

0 commit comments

Comments
 (0)