Skip to content

Commit c01a619

Browse files
committed
update code
1 parent 03371ec commit c01a619

File tree

2 files changed

+23
-54
lines changed

2 files changed

+23
-54
lines changed

tests/components/neurons/graded/test_RateCell.py

Lines changed: 5 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,12 @@
11
# %%
22

33
from jax import numpy as jnp, random, jit
4-
from ngcsimlib.context import Context
54
import numpy as np
65
np.random.seed(42)
76
from ngclearn.components import RateCell
8-
from ngcsimlib.compilers import compile_command, wrap_command
97
from numpy.testing import assert_array_equal
108

11-
from ngcsimlib.compilers.process import Process, transition
12-
from ngcsimlib.component import Component
13-
from ngcsimlib.compartment import Compartment
14-
from ngcsimlib.context import Context
15-
from ngcsimlib.utils.compartment import Get_Compartment_Batch
9+
from ngclearn import Context, MethodProcess
1610

1711

1812
def test_RateCell1():
@@ -26,17 +20,9 @@ def test_RateCell1():
2620
threshold=("none", 0.), integration_type="euler",
2721
batch_size=1, resist_scale=1., shape=None, is_stateful=True
2822
)
29-
advance_process = (Process("advance_proc") >> a.advance_state)
30-
ctx.wrap_and_add_command(jit(advance_process.pure), name="run")
31-
reset_process = (Process("reset_proc") >> a.reset)
32-
ctx.wrap_and_add_command(jit(reset_process.pure), name="reset")
23+
advance_process = (MethodProcess("advance_proc") >> a.advance_state)
24+
reset_process = (MethodProcess("reset_proc") >> a.reset)
3325

34-
# reset_cmd, reset_args = ctx.compile_by_key(a, compile_key="reset")
35-
# ctx.add_command(wrap_command(jit(ctx.reset)), name="reset")
36-
# advance_cmd, advance_args = ctx.compile_by_key(a, compile_key="advance_state")
37-
# ctx.add_command(wrap_command(jit(ctx.advance_state)), name="run")
38-
39-
@Context.dynamicCommand
4026
def clamp(x):
4127
a.j.set(x)
4228

@@ -46,11 +32,11 @@ def clamp(x):
4632
y_seq = jnp.asarray([[0.02, 0.04, 0.06, 0.08, 0.09999999999999999, 0.11999999999999998, 0.13999999999999999, 0.15999999999999998, 0.17999999999999998, 0.19999999999999998]], dtype=jnp.float32)
4733

4834
outs = []
49-
ctx.reset()
35+
reset_process.run()
5036
for ts in range(x_seq.shape[1]):
5137
x_t = jnp.array([[x_seq[0, ts]]]) ## get data at time t
5238
ctx.clamp(x_t)
53-
ctx.run(t=ts * 1., dt=dt)
39+
advance_process.run(t=ts * 1., dt=dt)
5440
outs.append(a.z.value)
5541
outs = jnp.concatenate(outs, axis=1)
5642
# print(outs)
Lines changed: 18 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,13 @@
11
# %%
22

33
from jax import numpy as jnp, random, jit
4-
from ngcsimlib.context import Context
4+
55
import numpy as np
66
np.random.seed(42)
77
from ngclearn.components import HebbianSynapse
8-
from ngcsimlib.compilers import compile_command, wrap_command
9-
from numpy.testing import assert_array_equal
108

11-
from ngcsimlib.compilers.process import Process, transition
12-
from ngcsimlib.component import Component
13-
from ngcsimlib.compartment import Compartment
14-
from ngcsimlib.context import Context
15-
from ngcsimlib.utils.compartment import Get_Compartment_Batch
9+
from numpy.testing import assert_array_equal
10+
from ngclearn import Context, MethodProcess
1611

1712

1813
def test_hebbianSynapse():
@@ -29,37 +24,23 @@ def test_hebbianSynapse():
2924

3025
with Context(name) as ctx:
3126
a = HebbianSynapse(
32-
name="a",
33-
shape=shape,
27+
name="a",
28+
shape=shape,
3429
resist_scale=resist_scale,
3530
batch_size=batch_size,
3631
prior = ("gaussian", 0.01)
3732
)
3833

39-
advance_process = (Process("advance_proc") >> a.advance_state)
40-
ctx.wrap_and_add_command(jit(advance_process.pure), name="run")
41-
reset_process = (Process("reset_proc") >> a.reset)
42-
ctx.wrap_and_add_command(jit(reset_process.pure), name="reset")
43-
evolve_process = (Process("evolve_proc") >> a.evolve)
44-
ctx.wrap_and_add_command(jit(evolve_process.pure), name="evolve")
45-
46-
# Compile and add commands
47-
# reset_cmd, reset_args = ctx.compile_by_key(a, compile_key="reset")
48-
# ctx.add_command(wrap_command(jit(reset_cmd)), name="reset")
49-
# advance_cmd, advance_args = ctx.compile_by_key(a, compile_key="advance_state")
50-
# ctx.add_command(wrap_command(jit(advance_cmd)), name="run")
51-
# evolve_cmd, evolve_args = ctx.compile_by_key(a, compile_key="evolve")
52-
# ctx.add_command(wrap_command(jit(evolve_cmd)), name="evolve")
53-
54-
@Context.dynamicCommand
34+
advance_process = (MethodProcess("advance_proc") >> a.advance_state)
35+
reset_process = (MethodProcess("reset_proc") >> a.reset)
36+
evolve_process = (MethodProcess("evolve_proc") >> a.evolve)
37+
5538
def clamp_inputs(x):
5639
a.inputs.set(x)
5740

58-
@Context.dynamicCommand
5941
def clamp_pre(x):
6042
a.pre.set(x)
6143

62-
@Context.dynamicCommand
6344
def clamp_post(x):
6445
a.post.set(x)
6546

@@ -70,16 +51,18 @@ def clamp_post(x):
7051
in_pre = jnp.ones((1, 10)) * 1.0
7152
in_post = jnp.ones((1, 5)) * 0.75
7253

73-
ctx.reset()
54+
reset_process.run()
7455
clamp_pre(in_pre)
7556
clamp_post(in_post)
76-
ctx.run(t=1. * dt, dt=dt)
77-
ctx.evolve(t=1. * dt, dt=dt)
57+
advance_process.run(t=1. * dt, dt=dt)
58+
evolve_process.run(t=1. * dt, dt=dt)
7859

79-
print(a.weights.value)
60+
print(a.weights.get())
8061

8162
# Basic assertions to check learning dynamics
82-
assert a.weights.value.shape == (10, 5), ""
83-
assert a.weights.value[0, 0] == 0.5, ""
63+
assert a.weights.get().shape == (10, 5), ""
64+
assert a.weights.get()[0, 0] == 0.5, ""
65+
66+
test_hebbianSynapse()
67+
8468

85-
# test_hebbianSynapse()

0 commit comments

Comments
 (0)