Skip to content

Commit 479d94a

Browse files
author
Alexander Ororbia
committed
bernoulli and poisson cells revised, unit-tested
1 parent bd5b88d commit 479d94a

File tree

4 files changed

+81
-62
lines changed

4 files changed

+81
-62
lines changed

ngclearn/components/input_encoders/bernoulliCell.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,6 @@
1-
#from ngclearn import resolver, Component, Compartment
21
from ngclearn.components.jaxComponent import JaxComponent
3-
from jax import numpy as jnp, random, jit
2+
from jax import numpy as jnp, random
43
from ngclearn.utils import tensorstats
5-
from functools import partial
64
from ngcsimlib.deprecators import deprecate_args
75
from ngcsimlib.logger import info, warn
86

@@ -49,6 +47,11 @@ def __init__(self, name, n_units, batch_size=1, **kwargs):
4947
@staticmethod
5048
def advance_state(t, key, inputs, tols):
5149
## NOTE: should `inputs` be checked if bounded to [0,1]?
50+
# print(key)
51+
# print(t)
52+
# print(inputs.shape)
53+
# print(tols.shape)
54+
# print("-----")
5255
key, *subkeys = random.split(key, 3)
5356
outputs = random.bernoulli(subkeys[0], p=inputs).astype(jnp.float32)
5457
# Updates time-of-last-spike (tols) variable:

ngclearn/components/input_encoders/poissonCell.py

Lines changed: 16 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -1,28 +1,12 @@
1-
from ngclearn import resolver, Component, Compartment
21
from ngclearn.components.jaxComponent import JaxComponent
3-
from jax import numpy as jnp, random, jit
2+
from jax import numpy as jnp, random
43
from ngclearn.utils import tensorstats
5-
from functools import partial
64
from ngcsimlib.deprecators import deprecate_args
75
from ngcsimlib.logger import info, warn
86

9-
@jit
10-
def _update_times(t, s, tols):
11-
"""
12-
Updates time-of-last-spike (tols) variable.
13-
14-
Args:
15-
t: current time (a scalar/int value)
16-
17-
s: binary spike vector
18-
19-
tols: current time-of-last-spike variable
20-
21-
Returns:
22-
updated tols variable
23-
"""
24-
_tols = (1. - s) * tols + (s * t)
25-
return _tols
7+
from ngcsimlib.compilers.process import transition
8+
#from ngcsimlib.component import Component
9+
from ngcsimlib.compartment import Compartment
2610

2711
class PoissonCell(JaxComponent):
2812
"""
@@ -79,33 +63,27 @@ def validate(self, dt=None, **validation_kwargs):
7963
)
8064
return valid
8165

66+
@transition(output_compartments=["outputs", "tols", "key"])
8267
@staticmethod
83-
def _advance_state(t, dt, target_freq, key, inputs, tols):
68+
def advance_state(t, dt, target_freq, key, inputs, tols):
8469
key, *subkeys = random.split(key, 2)
8570
pspike = inputs * (dt / 1000.) * target_freq
8671
eps = random.uniform(subkeys[0], inputs.shape, minval=0., maxval=1.,
8772
dtype=jnp.float32)
8873
outputs = (eps < pspike).astype(jnp.float32)
89-
tols = _update_times(t, outputs, tols)
90-
return outputs, tols, key
9174

92-
@resolver(_advance_state)
93-
def advance_state(self, outputs, tols, key):
94-
self.outputs.set(outputs)
95-
self.tols.set(tols)
96-
self.key.set(key)
75+
# Updates time-of-last-spike (tols) variable:
76+
# output = s = binary spike vector
77+
# tols = current time-of-last-spike variable
78+
tols = (1. - outputs) * tols + (outputs * t)
79+
return outputs, tols, key
9780

81+
@transition(output_compartments=["inputs", "outputs", "tols"])
9882
@staticmethod
99-
def _reset(batch_size, n_units):
83+
def reset(batch_size, n_units):
10084
restVals = jnp.zeros((batch_size, n_units))
10185
return restVals, restVals, restVals
10286

103-
@resolver(_reset)
104-
def reset(self, inputs, outputs, tols):
105-
self.inputs.set(inputs)
106-
self.outputs.set(outputs) #None
107-
self.tols.set(tols)
108-
10987
def save(self, directory, **kwargs):
11088
target_freq = (self.target_freq if isinstance(self.target_freq, float)
11189
else jnp.ones([[self.target_freq]]))
@@ -121,12 +99,9 @@ def load(self, directory, **kwargs):
12199
@classmethod
122100
def help(cls): ## component help function
123101
properties = {
124-
"cell_type": "PoissonCell - samples input to produce spikes, "
125-
"where dimension is a probability proportional to "
126-
"the dimension's magnitude/value/intensity and "
127-
"constrained by a maximum spike frequency (spikes "
128-
"follow "
129-
"a Poisson distribution)"
102+
"cell_type": "PoissonCell - samples input to produce spikes, where dimension is a probability proportional "
103+
"to the dimension's magnitude/value/intensity and constrained by a maximum spike frequency "
104+
"(spikes follow a Poisson distribution)"
130105
}
131106
compartment_props = {
132107
"inputs":

tests/components/input_encoders/test_bernoulliCell.py

Lines changed: 5 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -17,11 +17,9 @@ def test_bernoulliCell1():
1717
## create seeding keys
1818
dkey = random.PRNGKey(1234)
1919
dkey, *subkeys = random.split(dkey, 6)
20-
# in_dim = 9 # ... dimension of patch data ...
21-
# hid_dim = 9 # ... number of atoms in the dictionary matrix
2220
dt = 1. # ms
23-
T = 300 # ms # (OR) number of E-steps to take during inference
24-
# ---- build a sparse coding linear generative model with a Cauchy prior ----
21+
#T = 300 # ms
22+
# ---- build a simple Bernoulli cell system ----
2523
with Context("Circuit") as ctx:
2624
a = BernoulliCell(name="a", n_units=1, key=subkeys[0])
2725

@@ -33,17 +31,6 @@ def test_bernoulliCell1():
3331
>> a.reset)
3432
ctx.wrap_and_add_command(reset_process.pure, name="reset")
3533

36-
# ## create and compile core simulation commands
37-
# reset_cmd, reset_args = ctx.compile_by_key(
38-
# a, compile_key="reset"
39-
# )
40-
# ctx.add_command(wrap_command(jit(ctx.reset)), name="reset")
41-
# advance_cmd, advance_args = ctx.compile_by_key(
42-
# a,compile_key="advance_state"
43-
# )
44-
# ctx.add_command(wrap_command(jit(ctx.advance_state)), name="advance")
45-
46-
4734
## set up non-compiled utility commands
4835
@Context.dynamicCommand
4936
def clamp(x):
@@ -53,16 +40,16 @@ def clamp(x):
5340
x_seq = jnp.asarray([[1., 1., 0., 0., 1.]], dtype=jnp.float32)
5441

5542
outs = []
56-
#ctx.reset()
43+
ctx.reset()
5744
for ts in range(x_seq.shape[1]):
5845
x_t = jnp.array([[x_seq[0,ts]]]) ## get data at time t
5946
ctx.clamp(x_t)
60-
ctx.run(t=ts*1.)#, dt=1.)
47+
ctx.run(t=ts*1., dt=dt)
6148
outs.append(a.outputs.value)
6249
outs = jnp.concatenate(outs, axis=1)
6350

6451
## output should equal input
6552
assert_array_equal(outs, x_seq)
66-
#print(outs)
53+
print(outs)
6754

6855
test_bernoulliCell1()
Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
from jax import numpy as jnp, random, jit
2+
from ngcsimlib.context import Context
3+
import numpy as np
4+
np.random.seed(42)
5+
from ngclearn.components import PoissonCell
6+
from ngcsimlib.compilers import compile_command, wrap_command
7+
from numpy.testing import assert_array_equal
8+
9+
from ngcsimlib.compilers.process import Process, transition
10+
from ngcsimlib.component import Component
11+
from ngcsimlib.compartment import Compartment
12+
from ngcsimlib.context import Context
13+
from ngcsimlib.utils.compartment import Get_Compartment_Batch
14+
15+
16+
def test_poissonCell():
17+
## create seeding keys
18+
dkey = random.PRNGKey(1234)
19+
dkey, *subkeys = random.split(dkey, 6)
20+
dt = 1. # ms
21+
# T = 300 # ms
22+
# ---- build a simple Poisson cell system ----
23+
with Context("Circuit") as ctx:
24+
a = PoissonCell(name="a", n_units=1, target_freq=1000., key=subkeys[0])
25+
26+
advance_process = (Process()
27+
>> a.advance_state)
28+
ctx.wrap_and_add_command(advance_process.pure, name="run")
29+
30+
reset_process = (Process()
31+
>> a.reset)
32+
ctx.wrap_and_add_command(reset_process.pure, name="reset")
33+
34+
## set up non-compiled utility commands
35+
@Context.dynamicCommand
36+
def clamp(x):
37+
a.inputs.set(x)
38+
39+
## input spike train
40+
x_seq = jnp.asarray([[1., 1., 0., 0., 1.]], dtype=jnp.float32)
41+
42+
outs = []
43+
ctx.reset()
44+
for ts in range(x_seq.shape[1]):
45+
x_t = jnp.array([[x_seq[0, ts]]]) ## get data at time t
46+
ctx.clamp(x_t)
47+
ctx.run(t=ts * 1., dt=dt)
48+
outs.append(a.outputs.value)
49+
outs = jnp.concatenate(outs, axis=1)
50+
51+
## output should equal input
52+
assert_array_equal(outs, x_seq)
53+
54+
test_poissonCell()

0 commit comments

Comments
 (0)