Skip to content

Commit 54e0873

Browse files
author
Alexander Ororbia
committed
fixed test_laplacianErrorCell and laplace-cell bug
1 parent 4606a1c commit 54e0873

File tree

1 file changed

+11
-9
lines changed

1 file changed

+11
-9
lines changed

tests/components/neurons/graded/test_laplacianErrorCell.py

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from ngclearn.components import LaplacianErrorCell
77
from ngclearn import MethodProcess, Context
88

9-
def test_laplacianErrorCell():
9+
def test_laplacianErrorCell1():
1010
np.random.seed(42)
1111
name = "laplacian_error_ctx"
1212
dkey = random.PRNGKey(42)
@@ -16,17 +16,18 @@ def test_laplacianErrorCell():
1616
a = LaplacianErrorCell(
1717
name="a", n_units=1, batch_size=1, scale=1.0, shape=None
1818
)
19+
1920
advance_process = (MethodProcess("advance_proc") >> a.advance_state)
2021
reset_process = (MethodProcess("reset_proc") >> a.reset)
2122

22-
def clamp_modulator(x):
23-
a.modulator.set(x)
23+
def clamp_modulator(x):
24+
a.modulator.set(x)
2425

25-
def clamp_shift(x):
26-
a.shift.set(x)
26+
def clamp_shift(x):
27+
a.shift.set(x)
2728

28-
def clamp_target(x):
29-
a.target.set(x)
29+
def clamp_target(x):
30+
a.target.set(x)
3031

3132
## input sequence
3233
modulator_seq = jnp.ones((1, 10))
@@ -49,13 +50,13 @@ def clamp_target(x):
4950
target_t = jnp.array([[target_seq[0, ts]]])
5051
clamp_target(target_t)
5152
advance_process.run(t=ts * 1., dt=dt)
52-
dshift_outs.append(a.dshift.value)
53+
dshift_outs.append(a.dshift.get())
5354
# print(f"a.L.value: {a.L.value}")
5455
# print(f"a.shift.value: {a.shift.value}")
5556
# print(f"a.target.value: {a.target.value}")
5657
# print(f"a.Scale.value: {a.Scale.value}")
5758
# print(f"a.mask.value: {a.mask.value}")
58-
L_outs.append(a.L.value)
59+
L_outs.append(a.L.get())
5960

6061
dshift_outs = jnp.concatenate(dshift_outs, axis=1)
6162
L_outs = jnp.array(L_outs)[None] # (1, 10)
@@ -68,3 +69,4 @@ def clamp_target(x):
6869
np.testing.assert_allclose(dshift_outs, expected_dshift, atol=1e-5)
6970
np.testing.assert_allclose(L_outs, expected_L, atol=1e-5)
7071

72+
#test_laplacianErrorCell1()

0 commit comments

Comments
 (0)