Skip to content

Commit 0e3f674

Browse files
committed
update test case for test_sLIFCell.py
1 parent c80f2b5 commit 0e3f674

File tree

1 file changed

+9
-28
lines changed

1 file changed

+9
-28
lines changed
Lines changed: 9 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,11 @@
11
from jax import numpy as jnp, random, jit
2-
from ngcsimlib.context import Context
32
import numpy as np
43
np.random.seed(42)
54
from ngclearn.components import SLIFCell
6-
from ngcsimlib.compilers import compile_command, wrap_command
75
from numpy.testing import assert_array_equal
86

9-
from ngcsimlib.compilers.process import Process, transition
10-
from ngcsimlib.component import Component
11-
from ngcsimlib.compartment import Compartment
12-
from ngcsimlib.context import Context
13-
from ngcsimlib.utils.compartment import Get_Compartment_Batch
7+
from ngclearn import MethodProcess, Context
8+
149

1510
def test_sLIFCell1():
1611
name = "slif_ctx"
@@ -25,26 +20,12 @@ def test_sLIFCell1():
2520
name="a", n_units=1, tau_m=50., resist_m=10., thr=0.3, key=subkeys[0]
2621
)
2722

28-
#"""
29-
advance_process = (Process("advance_proc")
23+
advance_process = (MethodProcess("advance_proc")
3024
>> a.advance_state)
31-
#ctx.wrap_and_add_command(advance_process.pure, name="run")
32-
ctx.wrap_and_add_command(jit(advance_process.pure), name="run")
33-
34-
reset_process = (Process("reset_proc")
25+
reset_process = (MethodProcess("reset_proc")
3526
>> a.reset)
36-
ctx.wrap_and_add_command(jit(reset_process.pure), name="reset")
37-
#"""
38-
39-
"""
40-
reset_cmd, reset_args = ctx.compile_by_key(a, compile_key="reset")
41-
ctx.add_command(wrap_command(jit(ctx.reset)), name="reset")
42-
advance_cmd, advance_args = ctx.compile_by_key(a, compile_key="advance_state")
43-
ctx.add_command(wrap_command(jit(ctx.advance_state)), name="run")
44-
"""
4527

4628
## set up non-compiled utility commands
47-
@Context.dynamicCommand
4829
def clamp(x):
4930
a.j.set(x)
5031

@@ -54,15 +35,15 @@ def clamp(x):
5435
y_seq = jnp.asarray([[0., 1., 0., 0., 0., 1., 0.]], dtype=jnp.float32)
5536

5637
outs = []
57-
ctx.reset()
38+
reset_process.run()
5839
for ts in range(x_seq.shape[1]):
5940
x_t = jnp.array([[x_seq[0, ts]]]) ## get data at time t
60-
ctx.clamp(x_t)
61-
ctx.run(t=ts * 1., dt=dt)
62-
outs.append(a.s.value)
41+
clamp(x_t)
42+
advance_process.run(t=ts * 1., dt=dt)
43+
outs.append(a.s.get())
6344
outs = jnp.concatenate(outs, axis=1)
6445

6546
## output should equal input
6647
assert_array_equal(outs, y_seq)
6748

68-
#test_sLIFCell1()
49+
test_sLIFCell1()

0 commit comments

Comments
 (0)