Skip to content

Commit b98fd1a

Browse files
committed
update testing for graded neurons and input encoders
1 parent eeba012 commit b98fd1a

File tree

12 files changed

+78
-176
lines changed

12 files changed

+78
-176
lines changed

ngclearn/components/input_encoders/bernoulliCell.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,9 @@ def advance_state(self, t):
4848
@compilable
4949
def reset(self):
5050
restVals = jnp.zeros((self.batch_size.get(), self.n_units.get()))
51-
not self.inputs.targeted and self.inputs.set(restVals)
51+
# BUG: the self.inputs here does not have the targeted field
52+
# NOTE: Quick workaround is to check if targeted is in the input or not
53+
hasattr(self.inputs, "targeted") and not self.inputs.targeted and self.inputs.set(restVals)
5254
self.outputs.set(restVals)
5355
self.tols.set(restVals)
5456

ngclearn/components/input_encoders/latencyCell.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -211,7 +211,9 @@ def advance_state(self, t):
211211
@compilable
212212
def reset(self):
213213
restVals = jnp.zeros((self.batch_size.get(), self.n_units.get()))
214-
not self.inputs.targeted and self.inputs.set(restVals)
214+
# BUG: the self.inputs here does not have the targeted field
215+
# NOTE: Quick workaround is to check if targeted is in the input or not
216+
hasattr(self.inputs, "targeted") and not self.inputs.targeted and self.inputs.set(restVals)
215217
self.outputs.set(restVals)
216218
self.tols.set(restVals)
217219
self.mask.set(restVals)

ngclearn/components/input_encoders/phasorCell.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,9 @@ def advance_state(self, t, dt):
8888
@compilable
8989
def reset(self):
9090
restVals = jnp.zeros((self.batch_size.get(), self.n_units.get()))
91-
not self.inputs.targeted and self.inputs.set(restVals)
91+
# BUG: the self.inputs here does not have the targeted field
92+
# NOTE: Quick workaround is to check if targeted is in the input or not
93+
hasattr(self.inputs, "targeted") and not self.inputs.targeted and self.inputs.set(restVals)
9294
self.outputs.set(restVals)
9395
self.tols.set(restVals)
9496
self.angles.set(restVals)

ngclearn/components/input_encoders/poissonCell.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -60,8 +60,9 @@ def advance_state(self, t, dt):
6060
@compilable
6161
def reset(self):
6262
restVals = jnp.zeros((self.batch_size, self.n_units))
63-
if not self.inputs.targeted:
64-
self.inputs.set(restVals)
63+
# BUG: the self.inputs here does not have the targeted field
64+
# NOTE: Quick workaround is to check if targeted is in the input or not
65+
hasattr(self.inputs, "targeted") and not self.inputs.targeted and self.inputs.set(restVals)
6566
self.outputs.set(restVals)
6667
self.tols.set(restVals)
6768

tests/components/input_encoders/test_bernoulliCell.py

Lines changed: 8 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,13 @@
1+
# %%
2+
13
from jax import numpy as jnp, random, jit
2-
from ngcsimlib.context import Context
34
import numpy as np
45
np.random.seed(42)
56
from ngclearn.components import BernoulliCell
67
#from ngcsimlib.compilers import compile_command, wrap_command
78
from numpy.testing import assert_array_equal
89

9-
from ngcsimlib.compilers.process import Process, transition
10-
from ngclearn.utils import JaxProcess
11-
from ngcsimlib.context import Context
12-
#from ngcsimlib.utils.compartment import Get_Compartment_Batch
10+
from ngclearn import MethodProcess, Context
1311

1412

1513
def test_bernoulliCell1():
@@ -23,28 +21,25 @@ def test_bernoulliCell1():
2321
with Context(name) as ctx:
2422
a = BernoulliCell(name="a", n_units=1, key=subkeys[0])
2523

26-
advance_process = (JaxProcess("advance_proc")
24+
advance_process = (MethodProcess("advance_proc")
2725
>> a.advance_state)
28-
ctx.wrap_and_add_command(jit(advance_process.pure), name="run")
2926

30-
reset_process = (Process("reset_proc")
27+
reset_process = (MethodProcess("reset_proc")
3128
>> a.reset)
32-
ctx.wrap_and_add_command(jit(reset_process.pure), name="reset")
3329

3430
## set up non-compiled utility commands
35-
@Context.dynamicCommand
3631
def clamp(x):
3732
a.inputs.set(x)
3833

3934
## input spike train
4035
x_seq = jnp.asarray([[1., 1., 0., 0., 1.]], dtype=jnp.float32)
4136

4237
outs = []
43-
ctx.reset()
38+
reset_process.run()
4439
for ts in range(x_seq.shape[1]):
4540
x_t = jnp.array([[x_seq[0,ts]]]) ## get data at time t
46-
ctx.clamp(x_t)
47-
ctx.run(t=ts*1., dt=dt)
41+
clamp(x_t)
42+
advance_process.run(t=ts*1., dt=dt)
4843
outs.append(a.outputs.value)
4944
outs = jnp.concatenate(outs, axis=1)
5045

Lines changed: 14 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,11 @@
1+
# %%
2+
13
from jax import numpy as jnp, random, jit
2-
from ngcsimlib.context import Context
34
import numpy as np
45
np.random.seed(42)
56
from ngclearn.components import LatencyCell
6-
from ngcsimlib.compilers import compile_command, wrap_command
77
from numpy.testing import assert_array_equal
8-
9-
from ngcsimlib.compilers.process import Process, transition
10-
from ngcsimlib.component import Component
11-
from ngcsimlib.compartment import Compartment
12-
from ngcsimlib.context import Context
13-
from ngcsimlib.utils.compartment import Get_Compartment_Batch
14-
8+
from ngclearn import MethodProcess, Context
159

1610
def test_latencyCell1():
1711
name = "latency_ctx"
@@ -29,23 +23,19 @@ def test_latencyCell1():
2923
)
3024

3125
## create and compile core simulation commands
32-
advance_process = (Process("advance_proc")
26+
advance_process = (MethodProcess("advance_proc")
3327
>> a.advance_state)
34-
ctx.wrap_and_add_command(jit(advance_process.pure), name="advance")
35-
calc_spike_times_process = (Process("calc_sptimes_proc")
28+
calc_spike_times_process = (MethodProcess("calc_sptimes_proc")
3629
>> a.calc_spike_times)
37-
ctx.wrap_and_add_command(jit(calc_spike_times_process.pure), name="calc_spike_times")
38-
reset_process = (Process("reset_proc")
30+
reset_process = (MethodProcess("reset_proc")
3931
>> a.reset)
40-
ctx.wrap_and_add_command(jit(reset_process.pure), name="reset")
4132

4233
## set up non-compiled utility commands
43-
@Context.dynamicCommand
4434
def clamp(x):
4535
a.inputs.set(x)
4636

4737
## input spike train
48-
inputs = jnp.asarray([[0.02, 0.5, 1., 0.0]])
38+
x_t = jnp.asarray([[0.02, 0.5, 1., 0.0]])
4939

5040
targets = np.zeros((T, 4))
5141
targets[0, 2] = 1.
@@ -55,19 +45,19 @@ def clamp(x):
5545
targets = jnp.array(targets) ## gold-standard solution to check against
5646

5747
outs = []
58-
ctx.reset()
59-
ctx.clamp(inputs)
60-
ctx.calc_spike_times()
48+
reset_process.run()
49+
clamp(x_t)
50+
calc_spike_times_process.run()
6151
for ts in range(T):
62-
ctx.clamp(inputs)
63-
ctx.advance(t=ts * dt, dt=dt)
52+
clamp(x_t)
53+
advance_process.run(t=ts * dt, dt=dt)
6454
## naively extract simple statistics at time ts and print them to I/O
65-
s = a.outputs.value
55+
s = a.outputs.get()
6656
outs.append(s)
6757
#print(" {}: s {} ".format(ts, jnp.squeeze(s)))
6858
outs = jnp.concatenate(outs, axis=0)
6959

7060
## output should equal input
7161
assert_array_equal(outs, targets)
7262

73-
#test_latencyCell1()
63+
test_latencyCell1()

tests/components/input_encoders/test_phasorCell.py

Lines changed: 7 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,9 @@
11
from jax import numpy as jnp, random, jit
2-
from ngcsimlib.context import Context
32
import numpy as np
43
np.random.seed(42)
54
from ngclearn.components import PhasorCell
6-
#from ngcsimlib.compilers import compile_command, wrap_command
75
from numpy.testing import assert_array_equal
8-
9-
from ngcsimlib.compilers.process import Process, transition
10-
#from ngcsimlib.component import Component
11-
#from ngcsimlib.compartment import Compartment
12-
#from ngcsimlib.context import Context
13-
#from ngcsimlib.utils.compartment import Get_Compartment_Batch
6+
from ngclearn import MethodProcess, Context
147

158

169
def test_phasorCell1():
@@ -24,29 +17,26 @@ def test_phasorCell1():
2417
with Context(name) as ctx:
2518
a = PhasorCell(name="a", n_units=1, target_freq=1000., disable_phasor=True, key=subkeys[0])
2619

27-
advance_process = (Process("advance_proc")
20+
advance_process = (MethodProcess("advance_proc")
2821
>> a.advance_state)
29-
ctx.wrap_and_add_command(jit(advance_process.pure), name="run")
3022

31-
reset_process = (Process("reset_proc")
23+
reset_process = (MethodProcess("reset_proc")
3224
>> a.reset)
33-
ctx.wrap_and_add_command(jit(reset_process.pure), name="reset")
3425

3526
## set up non-compiled utility commands
36-
@Context.dynamicCommand
3727
def clamp(x):
3828
a.inputs.set(x)
3929

4030
## input spike train
4131
x_seq = jnp.asarray([[1., 1., 0., 0., 1.]], dtype=jnp.float32)
4232

4333
outs = []
44-
ctx.reset()
34+
reset_process.run()
4535
for ts in range(x_seq.shape[1]):
4636
x_t = jnp.array([[x_seq[0, ts]]]) ## get data at time t
47-
ctx.clamp(x_t)
48-
ctx.run(t=ts * 1., dt=dt)
49-
outs.append(a.outputs.value)
37+
clamp(x_t)
38+
advance_process.run(t=ts * 1., dt=dt)
39+
outs.append(a.outputs.get())
5040
#print(a.outputs.value)
5141
outs = jnp.concatenate(outs, axis=1)
5242
#print(outs)

tests/components/input_encoders/test_poissonCell.py

Lines changed: 7 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,10 @@
11
from jax import numpy as jnp, random, jit
2-
from ngcsimlib.context import Context
32
import numpy as np
43
np.random.seed(42)
54
from ngclearn.components import PoissonCell
6-
from ngcsimlib.compilers import compile_command, wrap_command
75
from numpy.testing import assert_array_equal
86

9-
from ngcsimlib.compilers.process import Process, transition
10-
from ngcsimlib.component import Component
11-
from ngcsimlib.compartment import Compartment
12-
from ngcsimlib.context import Context
13-
from ngcsimlib.utils.compartment import Get_Compartment_Batch
7+
from ngclearn import MethodProcess, Context
148

159

1610
def test_poissonCell1():
@@ -24,29 +18,26 @@ def test_poissonCell1():
2418
with Context(name) as ctx:
2519
a = PoissonCell(name="a", n_units=1, target_freq=1000., key=subkeys[0])
2620

27-
advance_process = (Process("advance_proc")
21+
advance_process = (MethodProcess("advance_proc")
2822
>> a.advance_state)
29-
ctx.wrap_and_add_command(jit(advance_process.pure), name="run")
3023

31-
reset_process = (Process("reset_proc")
24+
reset_process = (MethodProcess("reset_proc")
3225
>> a.reset)
33-
ctx.wrap_and_add_command(jit(reset_process.pure), name="reset")
3426

3527
## set up non-compiled utility commands
36-
@Context.dynamicCommand
3728
def clamp(x):
3829
a.inputs.set(x)
3930

4031
## input spike train
4132
x_seq = jnp.asarray([[1., 1., 0., 0., 1.]], dtype=jnp.float32)
4233

4334
outs = []
44-
ctx.reset()
35+
reset_process.run()
4536
for ts in range(x_seq.shape[1]):
4637
x_t = jnp.array([[x_seq[0, ts]]]) ## get data at time t
47-
ctx.clamp(x_t)
48-
ctx.run(t=ts * 1., dt=dt)
49-
outs.append(a.outputs.value)
38+
clamp(x_t)
39+
advance_process.run(t=ts * 1., dt=dt)
40+
outs.append(a.outputs.get())
5041
outs = jnp.concatenate(outs, axis=1)
5142

5243
## output should equal input

tests/components/neurons/graded/test_bernoulliErrorCell.py

Lines changed: 7 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,10 @@
11
# %%
22

33
from jax import numpy as jnp, random, jit
4-
from ngcsimlib.context import Context
54
import numpy as np
65
np.random.seed(42)
76
from ngclearn.components import BernoulliErrorCell
8-
from ngcsimlib.compilers import compile_command, wrap_command
9-
from numpy.testing import assert_array_equal
10-
11-
from ngcsimlib.compilers.process import Process, transition
12-
from ngcsimlib.component import Component
13-
from ngcsimlib.compartment import Compartment
14-
from ngcsimlib.context import Context
15-
from ngcsimlib.utils.compartment import Get_Compartment_Batch
16-
7+
from ngclearn import MethodProcess, Context
178

189
def test_bernoulliErrorCell():
1910
np.random.seed(42)
@@ -25,21 +16,12 @@ def test_bernoulliErrorCell():
2516
a = BernoulliErrorCell(
2617
name="a", n_units=1, batch_size=1, input_logits=False, shape=None
2718
)
28-
advance_process = (Process("advance_proc") >> a.advance_state)
29-
ctx.wrap_and_add_command(jit(advance_process.pure), name="run")
30-
reset_process = (Process("reset_proc") >> a.reset)
31-
ctx.wrap_and_add_command(jit(reset_process.pure), name="reset")
32-
33-
# reset_cmd, reset_args = ctx.compile_by_key(a, compile_key="reset")
34-
# ctx.add_command(wrap_command(jit(ctx.reset)), name="reset")
35-
# advance_cmd, advance_args = ctx.compile_by_key(a, compile_key="advance_state")
36-
# ctx.add_command(wrap_command(jit(ctx.advance_state)), name="run")
19+
advance_process = (MethodProcess("advance_proc") >> a.advance_state)
20+
reset_process = (MethodProcess("reset_proc") >> a.reset)
3721

38-
@Context.dynamicCommand
3922
def clamp(x):
4023
a.p.set(x)
4124

42-
@Context.dynamicCommand
4325
def clamp_target(x):
4426
a.target.set(x)
4527

@@ -50,13 +32,13 @@ def clamp_target(x):
5032
y_seq = jnp.asarray([[-2.8193381, -4976.9263, -2.1224928, -2939.0425, -1233.3916, -0.24662945, -708.30042, 0.28213939, 3550.8477, 1.3651246]], dtype=jnp.float32)
5133

5234
outs = []
53-
ctx.reset()
35+
reset_process.run()
5436
for ts in range(x_seq.shape[1]):
5537
x_t = jnp.array([[x_seq[0, ts]]]) ## get data at time t
56-
ctx.clamp(x_t)
38+
clamp(x_t)
5739
target_xt = jnp.array([[target_seq[0, ts]]])
58-
ctx.clamp_target(target_xt)
59-
ctx.run(t=ts * 1., dt=dt)
40+
clamp_target(target_xt)
41+
advance_process.run(t=ts * 1., dt=dt)
6042
outs.append(a.dp.value)
6143
outs = jnp.concatenate(outs, axis=1)
6244
# print(outs)

0 commit comments

Comments
 (0)