11from jax import numpy as jnp , random , jit
2- from ngcsimlib .context import Context
32import numpy as np
43np .random .seed (42 )
54from ngclearn .components import SLIFCell
6- from ngcsimlib .compilers import compile_command , wrap_command
75from numpy .testing import assert_array_equal
86
9- from ngcsimlib .compilers .process import Process , transition
10- from ngcsimlib .component import Component
11- from ngcsimlib .compartment import Compartment
12- from ngcsimlib .context import Context
13- from ngcsimlib .utils .compartment import Get_Compartment_Batch
7+ from ngclearn import MethodProcess , Context
8+
149
1510def test_sLIFCell1 ():
1611 name = "slif_ctx"
@@ -25,26 +20,12 @@ def test_sLIFCell1():
2520 name = "a" , n_units = 1 , tau_m = 50. , resist_m = 10. , thr = 0.3 , key = subkeys [0 ]
2621 )
2722
28- #"""
29- advance_process = (Process ("advance_proc" )
23+ advance_process = (MethodProcess ("advance_proc" )
3024 >> a .advance_state )
31- #ctx.wrap_and_add_command(advance_process.pure, name="run")
32- ctx .wrap_and_add_command (jit (advance_process .pure ), name = "run" )
33-
34- reset_process = (Process ("reset_proc" )
25+ reset_process = (MethodProcess ("reset_proc" )
3526 >> a .reset )
36- ctx .wrap_and_add_command (jit (reset_process .pure ), name = "reset" )
37- #"""
38-
39- """
40- reset_cmd, reset_args = ctx.compile_by_key(a, compile_key="reset")
41- ctx.add_command(wrap_command(jit(ctx.reset)), name="reset")
42- advance_cmd, advance_args = ctx.compile_by_key(a, compile_key="advance_state")
43- ctx.add_command(wrap_command(jit(ctx.advance_state)), name="run")
44- """
4527
4628 ## set up non-compiled utility commands
47- @Context .dynamicCommand
4829 def clamp (x ):
4930 a .j .set (x )
5031
@@ -54,15 +35,15 @@ def clamp(x):
5435 y_seq = jnp .asarray ([[0. , 1. , 0. , 0. , 0. , 1. , 0. ]], dtype = jnp .float32 )
5536
5637 outs = []
57- ctx . reset ()
38+ reset_process . run ()
5839 for ts in range (x_seq .shape [1 ]):
5940 x_t = jnp .array ([[x_seq [0 , ts ]]]) ## get data at time t
60- ctx . clamp (x_t )
61- ctx .run (t = ts * 1. , dt = dt )
62- outs .append (a .s .value )
41+ clamp (x_t )
42+ advance_process .run (t = ts * 1. , dt = dt )
43+ outs .append (a .s .get () )
6344 outs = jnp .concatenate (outs , axis = 1 )
6445
6546 ## output should equal input
6647 assert_array_equal (outs , y_seq )
6748
68- # test_sLIFCell1()
49+ test_sLIFCell1 ()
0 commit comments