@@ -25,20 +25,23 @@ def test_bernoulliCell1():
2525 with Context ("Circuit" ) as ctx :
2626 a = BernoulliCell (name = "a" , n_units = 1 , key = subkeys [0 ])
2727
28- myProcess = (Process ()
29- >> a .advance_state )
28+ advance_process = (Process ()
29+ >> a .advance_state )
30+ ctx .wrap_and_add_command (advance_process .pure , name = "run" )
3031
31- ctx .wrap_and_add_command (myProcess .pure , name = "run" )
32+ reset_process = (Process ()
33+ >> a .reset )
34+ ctx .wrap_and_add_command (reset_process .pure , name = "reset" )
3235
33- ## create and compile core simulation commands
34- reset_cmd , reset_args = ctx .compile_by_key (
35- a , compile_key = "reset"
36- )
37- ctx .add_command (wrap_command (jit (ctx .reset )), name = "reset" )
38- advance_cmd , advance_args = ctx .compile_by_key (
39- a ,compile_key = "advance_state"
40- )
41- ctx .add_command (wrap_command (jit (ctx .advance_state )), name = "advance" )
36+ # # # create and compile core simulation commands
37+ # reset_cmd, reset_args = ctx.compile_by_key(
38+ # a, compile_key="reset"
39+ # )
40+ # ctx.add_command(wrap_command(jit(ctx.reset)), name="reset")
41+ # advance_cmd, advance_args = ctx.compile_by_key(
42+ # a,compile_key="advance_state"
43+ # )
44+ # ctx.add_command(wrap_command(jit(ctx.advance_state)), name="advance")
4245
4346
4447 ## set up non-compiled utility commands
@@ -50,11 +53,11 @@ def clamp(x):
5053 x_seq = jnp .asarray ([[1. , 1. , 0. , 0. , 1. ]], dtype = jnp .float32 )
5154
5255 outs = []
53- ctx .reset ()
56+ # ctx.reset()
5457 for ts in range (x_seq .shape [1 ]):
5558 x_t = jnp .array ([[x_seq [0 ,ts ]]]) ## get data at time t
5659 ctx .clamp (x_t )
57- ctx .advance (t = ts * 1. , dt = 1. )
60+ ctx .run (t = ts * 1. ) # , dt=1.)
5861 outs .append (a .outputs .value )
5962 outs = jnp .concatenate (outs , axis = 1 )
6063
0 commit comments