Skip to content

Commit bd5b88d

Browse files
author
Alexander Ororbia
committed
edit to bern-cell
1 parent 7c67169 commit bd5b88d

File tree

2 files changed

+20
-17
lines changed

2 files changed

+20
-17
lines changed

ngclearn/components/input_encoders/bernoulliCell.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,8 @@
66
from ngcsimlib.deprecators import deprecate_args
77
from ngcsimlib.logger import info, warn
88

9-
from ngcsimlib.compilers.process import Process, transition
10-
from ngcsimlib.component import Component
9+
from ngcsimlib.compilers.process import transition
10+
#from ngcsimlib.component import Component
1111
from ngcsimlib.compartment import Compartment
1212

1313
class BernoulliCell(JaxComponent):
@@ -49,7 +49,7 @@ def __init__(self, name, n_units, batch_size=1, **kwargs):
4949
@staticmethod
5050
def advance_state(t, key, inputs, tols):
5151
## NOTE: should `inputs` be checked if bounded to [0,1]?
52-
key, *subkeys = random.split(key, 2)
52+
key, *subkeys = random.split(key, 3)
5353
outputs = random.bernoulli(subkeys[0], p=inputs).astype(jnp.float32)
5454
# Updates time-of-last-spike (tols) variable:
5555
# output = s = binary spike vector

tests/components/input_encoders/test_bernoulliCell.py

Lines changed: 17 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -25,20 +25,23 @@ def test_bernoulliCell1():
2525
with Context("Circuit") as ctx:
2626
a = BernoulliCell(name="a", n_units=1, key=subkeys[0])
2727

28-
myProcess = (Process()
29-
>> a.advance_state)
28+
advance_process = (Process()
29+
>> a.advance_state)
30+
ctx.wrap_and_add_command(advance_process.pure, name="run")
3031

31-
ctx.wrap_and_add_command(myProcess.pure, name="run")
32+
reset_process = (Process()
33+
>> a.reset)
34+
ctx.wrap_and_add_command(reset_process.pure, name="reset")
3235

33-
## create and compile core simulation commands
34-
reset_cmd, reset_args = ctx.compile_by_key(
35-
a, compile_key="reset"
36-
)
37-
ctx.add_command(wrap_command(jit(ctx.reset)), name="reset")
38-
advance_cmd, advance_args = ctx.compile_by_key(
39-
a,compile_key="advance_state"
40-
)
41-
ctx.add_command(wrap_command(jit(ctx.advance_state)), name="advance")
36+
# ## create and compile core simulation commands
37+
# reset_cmd, reset_args = ctx.compile_by_key(
38+
# a, compile_key="reset"
39+
# )
40+
# ctx.add_command(wrap_command(jit(ctx.reset)), name="reset")
41+
# advance_cmd, advance_args = ctx.compile_by_key(
42+
# a,compile_key="advance_state"
43+
# )
44+
# ctx.add_command(wrap_command(jit(ctx.advance_state)), name="advance")
4245

4346

4447
## set up non-compiled utility commands
@@ -50,11 +53,11 @@ def clamp(x):
5053
x_seq = jnp.asarray([[1., 1., 0., 0., 1.]], dtype=jnp.float32)
5154

5255
outs = []
53-
ctx.reset()
56+
#ctx.reset()
5457
for ts in range(x_seq.shape[1]):
5558
x_t = jnp.array([[x_seq[0,ts]]]) ## get data at time t
5659
ctx.clamp(x_t)
57-
ctx.advance(t=ts*1., dt=1.)
60+
ctx.run(t=ts*1.)#, dt=1.)
5861
outs.append(a.outputs.value)
5962
outs = jnp.concatenate(outs, axis=1)
6063

0 commit comments

Comments
 (0)