11# %%
22
33from jax import numpy as jnp , random , jit
4- from ngcsimlib .context import Context
54import numpy as np
65np .random .seed (42 )
76from ngclearn .components import HebbianPatchedSynapse
8- from ngcsimlib .compilers import compile_command , wrap_command
97from numpy .testing import assert_array_equal
108
11- from ngcsimlib .compilers .process import Process , transition
12- from ngcsimlib .component import Component
13- from ngcsimlib .compartment import Compartment
14- from ngcsimlib .context import Context
15- from ngcsimlib .utils .compartment import Get_Compartment_Batch
16-
9+ from ngclearn import MethodProcess , Context
1710
1811def test_hebbianPatchedSynapse ():
1912 np .random .seed (42 )
@@ -31,58 +24,45 @@ def test_hebbianPatchedSynapse():
3124
3225 with Context (name ) as ctx :
3326 a = HebbianPatchedSynapse (
34- name = "a" ,
35- shape = shape ,
36- n_sub_models = n_sub_models ,
27+ name = "a" ,
28+ shape = shape ,
29+ n_sub_models = n_sub_models ,
3730 stride_shape = stride_shape ,
3831 resist_scale = resist_scale ,
3932 batch_size = batch_size
4033 )
4134
42- advance_process = (Process ("advance_proc" ) >> a .advance_state )
43- ctx .wrap_and_add_command (jit (advance_process .pure ), name = "run" )
44- reset_process = (Process ("reset_proc" ) >> a .reset )
45- ctx .wrap_and_add_command (jit (reset_process .pure ), name = "reset" )
46- evolve_process = (Process ("evolve_proc" ) >> a .evolve )
47- ctx .wrap_and_add_command (jit (evolve_process .pure ), name = "evolve" )
48-
49- # Compile and add commands
50- # reset_cmd, reset_args = ctx.compile_by_key(a, compile_key="reset")
51- # ctx.add_command(wrap_command(jit(reset_cmd)), name="reset")
52- # advance_cmd, advance_args = ctx.compile_by_key(a, compile_key="advance_state")
53- # ctx.add_command(wrap_command(jit(advance_cmd)), name="run")
54- # evolve_cmd, evolve_args = ctx.compile_by_key(a, compile_key="evolve")
55- # ctx.add_command(wrap_command(jit(evolve_cmd)), name="evolve")
56-
57- @Context .dynamicCommand
35+ advance_process = (MethodProcess ("advance_proc" ) >> a .advance_state )
36+ reset_process = (MethodProcess ("reset_proc" ) >> a .reset )
37+ evolve_process = (MethodProcess ("evolve_proc" ) >> a .evolve )
38+
5839 def clamp_inputs (x ):
5940 a .inputs .set (x )
6041
61- @Context .dynamicCommand
6242 def clamp_pre (x ):
6343 a .pre .set (x )
6444
65- @Context .dynamicCommand
6645 def clamp_post (x ):
6746 a .post .set (x )
6847
69- a .weights .set (jnp .ones ((12 , 12 )) * 0.5 )
48+ a .weights .set (jnp .ones ((12 , 12 )) * 0.5 )
7049
7150 in_pre = jnp .ones ((10 , 12 )) * 1.0
7251 in_post = jnp .ones ((10 , 12 )) * 0.75
7352
74- ctx . reset ()
53+ reset_process . run ()
7554 clamp_pre (in_pre )
7655 clamp_post (in_post )
77- ctx .run (t = 1. * dt , dt = dt )
78- ctx . evolve (t = 1. * dt , dt = dt )
56+ advance_process .run (t = 1. * dt , dt = dt )
57+ evolve_process . run (t = 1. * dt , dt = dt )
7958
80- print (a .weights .value )
59+ print (a .weights .get () )
8160
8261 # Basic assertions to check learning dynamics
83- assert a .weights .value .shape == (12 , 12 ), ""
84- assert a .weights .value [0 , 0 ] == 0.5 , ""
62+ assert a .weights .get ().shape == (12 , 12 ), ""
63+ assert a .weights .get ()[0 , 0 ] == 0.5 , ""
64+
8565
66+ test_hebbianPatchedSynapse ()
8667
8768
88- # test_hebbianPatchedSynapse()
0 commit comments