11# %%
22
33from jax import numpy as jnp , random , jit
4- from ngcsimlib . context import Context
4+
55import numpy as np
66np .random .seed (42 )
77from ngclearn .components import HebbianSynapse
8- from ngcsimlib .compilers import compile_command , wrap_command
9- from numpy .testing import assert_array_equal
108
11- from ngcsimlib .compilers .process import Process , transition
12- from ngcsimlib .component import Component
13- from ngcsimlib .compartment import Compartment
14- from ngcsimlib .context import Context
15- from ngcsimlib .utils .compartment import Get_Compartment_Batch
9+ from numpy .testing import assert_array_equal
10+ from ngclearn import Context , MethodProcess
1611
1712
1813def test_hebbianSynapse ():
@@ -29,37 +24,23 @@ def test_hebbianSynapse():
2924
3025 with Context (name ) as ctx :
3126 a = HebbianSynapse (
32- name = "a" ,
33- shape = shape ,
27+ name = "a" ,
28+ shape = shape ,
3429 resist_scale = resist_scale ,
3530 batch_size = batch_size ,
3631 prior = ("gaussian" , 0.01 )
3732 )
3833
39- advance_process = (Process ("advance_proc" ) >> a .advance_state )
40- ctx .wrap_and_add_command (jit (advance_process .pure ), name = "run" )
41- reset_process = (Process ("reset_proc" ) >> a .reset )
42- ctx .wrap_and_add_command (jit (reset_process .pure ), name = "reset" )
43- evolve_process = (Process ("evolve_proc" ) >> a .evolve )
44- ctx .wrap_and_add_command (jit (evolve_process .pure ), name = "evolve" )
45-
46- # Compile and add commands
47- # reset_cmd, reset_args = ctx.compile_by_key(a, compile_key="reset")
48- # ctx.add_command(wrap_command(jit(reset_cmd)), name="reset")
49- # advance_cmd, advance_args = ctx.compile_by_key(a, compile_key="advance_state")
50- # ctx.add_command(wrap_command(jit(advance_cmd)), name="run")
51- # evolve_cmd, evolve_args = ctx.compile_by_key(a, compile_key="evolve")
52- # ctx.add_command(wrap_command(jit(evolve_cmd)), name="evolve")
53-
54- @Context .dynamicCommand
34+ advance_process = (MethodProcess ("advance_proc" ) >> a .advance_state )
35+ reset_process = (MethodProcess ("reset_proc" ) >> a .reset )
36+ evolve_process = (MethodProcess ("evolve_proc" ) >> a .evolve )
37+
5538 def clamp_inputs (x ):
5639 a .inputs .set (x )
5740
58- @Context .dynamicCommand
5941 def clamp_pre (x ):
6042 a .pre .set (x )
6143
62- @Context .dynamicCommand
6344 def clamp_post (x ):
6445 a .post .set (x )
6546
@@ -70,16 +51,18 @@ def clamp_post(x):
7051 in_pre = jnp .ones ((1 , 10 )) * 1.0
7152 in_post = jnp .ones ((1 , 5 )) * 0.75
7253
73- ctx . reset ()
54+ reset_process . run ()
7455 clamp_pre (in_pre )
7556 clamp_post (in_post )
76- ctx .run (t = 1. * dt , dt = dt )
77- ctx . evolve (t = 1. * dt , dt = dt )
57+ advance_process .run (t = 1. * dt , dt = dt )
58+ evolve_process . run (t = 1. * dt , dt = dt )
7859
79- print (a .weights .value )
60+ print (a .weights .get () )
8061
8162 # Basic assertions to check learning dynamics
82- assert a .weights .value .shape == (10 , 5 ), ""
83- assert a .weights .value [0 , 0 ] == 0.5 , ""
63+ assert a .weights .get ().shape == (10 , 5 ), ""
64+ assert a .weights .get ()[0 , 0 ] == 0.5 , ""
65+
66+ test_hebbianSynapse ()
67+
8468
85- # test_hebbianSynapse()
0 commit comments