Skip to content

Commit 3193b72

Browse files
committed
update test for bernoulli cell
1 parent f453623 commit 3193b72

File tree

1 file changed

+66
-0
lines changed

1 file changed

+66
-0
lines changed
Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
# %%
2+
3+
from jax import numpy as jnp, random, jit
4+
from ngcsimlib.context import Context
5+
import numpy as np
6+
np.random.seed(42)
7+
from ngclearn.components import BernoulliErrorCell
8+
from ngcsimlib.compilers import compile_command, wrap_command
9+
from numpy.testing import assert_array_equal
10+
11+
from ngcsimlib.compilers.process import Process, transition
12+
from ngcsimlib.component import Component
13+
from ngcsimlib.compartment import Compartment
14+
from ngcsimlib.context import Context
15+
from ngcsimlib.utils.compartment import Get_Compartment_Batch
16+
17+
18+
def test_bernoulliErrorCell():
19+
np.random.seed(42)
20+
name = "bernoulli_error_ctx"
21+
dkey = random.PRNGKey(42)
22+
dkey, *subkeys = random.split(dkey, 100)
23+
dt = 1. # ms
24+
with Context(name) as ctx:
25+
a = BernoulliErrorCell(
26+
name="a", n_units=1, batch_size=1, input_logits=False, shape=None
27+
)
28+
advance_process = (Process() >> a.advance_state)
29+
ctx.wrap_and_add_command(jit(advance_process.pure), name="run")
30+
reset_process = (Process() >> a.reset)
31+
ctx.wrap_and_add_command(jit(reset_process.pure), name="reset")
32+
33+
# reset_cmd, reset_args = ctx.compile_by_key(a, compile_key="reset")
34+
# ctx.add_command(wrap_command(jit(ctx.reset)), name="reset")
35+
# advance_cmd, advance_args = ctx.compile_by_key(a, compile_key="advance_state")
36+
# ctx.add_command(wrap_command(jit(ctx.advance_state)), name="run")
37+
38+
@Context.dynamicCommand
39+
def clamp(x):
40+
a.p.set(x)
41+
42+
@Context.dynamicCommand
43+
def clamp_target(x):
44+
a.target.set(x)
45+
46+
## input spike train
47+
x_seq = jnp.asarray(np.random.randn(1, 10))
48+
target_seq = (jnp.arange(10)[None] - 5.0) / 2.0
49+
## desired output/epsp pulses
50+
y_seq = jnp.asarray([[-2.8193381, -4976.9263, -2.1224928, -2939.0425, -1233.3916, -0.24662945, -708.30042, 0.28213939, 3550.8477, 1.3651246]], dtype=jnp.float32)
51+
52+
outs = []
53+
ctx.reset()
54+
for ts in range(x_seq.shape[1]):
55+
x_t = jnp.array([[x_seq[0, ts]]]) ## get data at time t
56+
ctx.clamp(x_t)
57+
target_xt = jnp.array([[target_seq[0, ts]]])
58+
ctx.clamp_target(target_xt)
59+
ctx.run(t=ts * 1., dt=dt)
60+
outs.append(a.dp.value)
61+
outs = jnp.concatenate(outs, axis=1)
62+
# print(outs)
63+
## output should equal input
64+
np.testing.assert_allclose(outs, y_seq, atol=1e-7)
65+
66+
test_bernoulliErrorCell()

0 commit comments

Comments
 (0)