Skip to content

Commit b51cfd0

Browse files
committed
example rate cell test
1 parent 1f51c2e commit b51cfd0

File tree

1 file changed

+51
-0
lines changed

1 file changed

+51
-0
lines changed
Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
from jax import numpy as jnp, random, jit
2+
from ngcsimlib.context import Context
3+
import numpy as np
4+
np.random.seed(42)
5+
from ngclearn.components import RateCell
6+
from ngcsimlib.compilers import compile_command, wrap_command
7+
8+
def test_rateCell1():
9+
## create seeding keys
10+
dkey = random.PRNGKey(1234)
11+
dkey, *subkeys = random.split(dkey, 6)
12+
# in_dim = 9 # ... dimension of patch data ...
13+
# hid_dim = 9 # ... number of atoms in the dictionary matrix
14+
dt = 1. # ms
15+
T = 300 # ms # (OR) number of E-steps to take during inference
16+
# ---- build a sparse coding linear generative model with a Cauchy prior ----
17+
with Context("Circuit") as circuit:
18+
a = RateCell(name="a", n_units=1, tau_m=0.,
19+
act_fx="identity", key=subkeys[0])
20+
b = RateCell(name="b", n_units=1, tau_m=0.,
21+
act_fx="identity", key=subkeys[1])
22+
23+
# wire output compartment (rate-coded output zF) of RateCell `a` to input compartment of HebbianSynapse `Wab`
24+
25+
# wire output compartment of HebbianSynapse `Wab` to input compartment (electrical current j) RateCell `b`
26+
b.j << a.zF
27+
28+
## create and compile core simulation commands
29+
reset_cmd, reset_args = circuit.compile_by_key(a, b, compile_key="reset")
30+
circuit.add_command(wrap_command(jit(circuit.reset)), name="reset")
31+
32+
advance_cmd, advance_args = circuit.compile_by_key(a, b,
33+
compile_key="advance_state")
34+
circuit.add_command(wrap_command(jit(circuit.advance_state)), name="advance")
35+
36+
37+
## set up non-compiled utility commands
38+
@Context.dynamicCommand
39+
def clamp(x):
40+
a.j.set(x)
41+
42+
x_seq = jnp.asarray([[1, 1, 0, 0, 1]], dtype=jnp.float32)
43+
44+
circuit.reset()
45+
for ts in range(x_seq.shape[1]):
46+
x_t = jnp.expand_dims(x_seq[0,ts], axis=0) ## get data at time t
47+
circuit.clamp(x_t)
48+
circuit.advance(t=ts*1., dt=1.)
49+
50+
print(a.zF.value)
51+
# assertion here if needed!

0 commit comments

Comments
 (0)