11# %%
22
33from jax import numpy as jnp , random , jit
4- from ngcsimlib .context import Context
54import numpy as np
65np .random .seed (42 )
76from ngclearn .components import BernoulliErrorCell
8- from ngcsimlib .compilers import compile_command , wrap_command
9- from numpy .testing import assert_array_equal
10-
11- from ngcsimlib .compilers .process import Process , transition
12- from ngcsimlib .component import Component
13- from ngcsimlib .compartment import Compartment
14- from ngcsimlib .context import Context
15- from ngcsimlib .utils .compartment import Get_Compartment_Batch
16-
7+ from ngclearn import MethodProcess , Context
178
189def test_bernoulliErrorCell ():
1910 np .random .seed (42 )
@@ -25,21 +16,12 @@ def test_bernoulliErrorCell():
2516 a = BernoulliErrorCell (
2617 name = "a" , n_units = 1 , batch_size = 1 , input_logits = False , shape = None
2718 )
28- advance_process = (Process ("advance_proc" ) >> a .advance_state )
29- ctx .wrap_and_add_command (jit (advance_process .pure ), name = "run" )
30- reset_process = (Process ("reset_proc" ) >> a .reset )
31- ctx .wrap_and_add_command (jit (reset_process .pure ), name = "reset" )
32-
33- # reset_cmd, reset_args = ctx.compile_by_key(a, compile_key="reset")
34- # ctx.add_command(wrap_command(jit(ctx.reset)), name="reset")
35- # advance_cmd, advance_args = ctx.compile_by_key(a, compile_key="advance_state")
36- # ctx.add_command(wrap_command(jit(ctx.advance_state)), name="run")
19+ advance_process = (MethodProcess ("advance_proc" ) >> a .advance_state )
20+ reset_process = (MethodProcess ("reset_proc" ) >> a .reset )
3721
38- @Context .dynamicCommand
3922 def clamp (x ):
4023 a .p .set (x )
4124
42- @Context .dynamicCommand
4325 def clamp_target (x ):
4426 a .target .set (x )
4527
@@ -50,13 +32,13 @@ def clamp_target(x):
5032 y_seq = jnp .asarray ([[- 2.8193381 , - 4976.9263 , - 2.1224928 , - 2939.0425 , - 1233.3916 , - 0.24662945 , - 708.30042 , 0.28213939 , 3550.8477 , 1.3651246 ]], dtype = jnp .float32 )
5133
5234 outs = []
53- ctx . reset ()
35+ reset_process . run ()
5436 for ts in range (x_seq .shape [1 ]):
5537 x_t = jnp .array ([[x_seq [0 , ts ]]]) ## get data at time t
56- ctx . clamp (x_t )
38+ clamp (x_t )
5739 target_xt = jnp .array ([[target_seq [0 , ts ]]])
58- ctx . clamp_target (target_xt )
59- ctx .run (t = ts * 1. , dt = dt )
40+ clamp_target (target_xt )
41+ advance_process .run (t = ts * 1. , dt = dt )
6042 outs .append (a .dp .value )
6143 outs = jnp .concatenate (outs , axis = 1 )
6244 # print(outs)
0 commit comments