66from ngclearn .components import LaplacianErrorCell
77from ngclearn import MethodProcess , Context
88
9- def test_laplacianErrorCell ():
9+ def test_laplacianErrorCell1 ():
1010 np .random .seed (42 )
1111 name = "laplacian_error_ctx"
1212 dkey = random .PRNGKey (42 )
@@ -16,17 +16,18 @@ def test_laplacianErrorCell():
1616 a = LaplacianErrorCell (
1717 name = "a" , n_units = 1 , batch_size = 1 , scale = 1.0 , shape = None
1818 )
19+
1920 advance_process = (MethodProcess ("advance_proc" ) >> a .advance_state )
2021 reset_process = (MethodProcess ("reset_proc" ) >> a .reset )
2122
22- def clamp_modulator (x ):
23- a .modulator .set (x )
23+ def clamp_modulator (x ):
24+ a .modulator .set (x )
2425
25- def clamp_shift (x ):
26- a .shift .set (x )
26+ def clamp_shift (x ):
27+ a .shift .set (x )
2728
28- def clamp_target (x ):
29- a .target .set (x )
29+ def clamp_target (x ):
30+ a .target .set (x )
3031
3132 ## input sequence
3233 modulator_seq = jnp .ones ((1 , 10 ))
@@ -49,13 +50,13 @@ def clamp_target(x):
4950 target_t = jnp .array ([[target_seq [0 , ts ]]])
5051 clamp_target (target_t )
5152 advance_process .run (t = ts * 1. , dt = dt )
52- dshift_outs .append (a .dshift .value )
53+ dshift_outs .append (a .dshift .get () )
5354 # print(f"a.L.value: {a.L.value}")
5455 # print(f"a.shift.value: {a.shift.value}")
5556 # print(f"a.target.value: {a.target.value}")
5657 # print(f"a.Scale.value: {a.Scale.value}")
5758 # print(f"a.mask.value: {a.mask.value}")
58- L_outs .append (a .L .value )
59+ L_outs .append (a .L .get () )
5960
6061 dshift_outs = jnp .concatenate (dshift_outs , axis = 1 )
6162 L_outs = jnp .array (L_outs )[None ] # (1, 10)
@@ -68,3 +69,4 @@ def clamp_target(x):
6869 np .testing .assert_allclose (dshift_outs , expected_dshift , atol = 1e-5 )
6970 np .testing .assert_allclose (L_outs , expected_L , atol = 1e-5 )
7071
72+ #test_laplacianErrorCell1()
0 commit comments