Skip to content

Commit 21f43f8

Browse files
author
Alexander Ororbia
committed
attempted rewrite of bernoulli-cell
1 parent c4fe072 commit 21f43f8

File tree

5 files changed

+89
-39
lines changed

5 files changed

+89
-39
lines changed

ngclearn/components/base_monitor.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,8 @@
22

33
from ngclearn import Component, Compartment
44
from ngclearn import numpy as np
5-
from ngcsimlib.utils import add_component_resolver, add_resolver_meta, \
6-
get_current_path
5+
#from ngcsimlib.utils import add_component_resolver, add_resolver_meta, \
6+
from ngcsimlib.utils import get_current_path
77
from ngcsimlib.logger import warn, critical
88
import matplotlib.pyplot as plt
99

ngclearn/components/input_encoders/bernoulliCell.py

Lines changed: 18 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -1,36 +1,22 @@
1-
from ngclearn import resolver, Component, Compartment
1+
#from ngclearn import resolver, Component, Compartment
22
from ngclearn.components.jaxComponent import JaxComponent
33
from jax import numpy as jnp, random, jit
44
from ngclearn.utils import tensorstats
55
from functools import partial
66
from ngcsimlib.deprecators import deprecate_args
77
from ngcsimlib.logger import info, warn
88

9-
@jit
10-
def _update_times(t, s, tols):
11-
"""
12-
Updates time-of-last-spike (tols) variable.
13-
14-
Args:
15-
t: current time (a scalar/int value)
16-
17-
s: binary spike vector
18-
19-
tols: current time-of-last-spike variable
20-
21-
Returns:
22-
updated tols variable
23-
"""
24-
_tols = (1. - s) * tols + (s * t)
25-
return _tols
9+
from ngcsimlib.compilers.process import Process, transition
10+
from ngcsimlib.component import Component
11+
from ngcsimlib.compartment import Compartment
2612

2713
class BernoulliCell(JaxComponent):
2814
"""
2915
A Bernoulli cell that produces spikes by sampling a Bernoulli distribution
3016
on-the-fly (to produce data-scaled Bernoulli spike trains).
3117
3218
| --- Cell Input Compartments: ---
33-
| inputs - input (takes in external signals)
19+
| inputs - input (takes in external signals -- should be probabilities w/ values in [0,1])
3420
| --- Cell State Compartments: ---
3521
| key - JAX PRNG key
3622
| --- Cell Output Compartments: ---
@@ -41,10 +27,13 @@ class BernoulliCell(JaxComponent):
4127
name: the string name of this cell
4228
4329
n_units: number of cellular entities (neural population size)
30+
31+
batch_size: batch size dimension of this cell (Default: 1)
4432
"""
4533

4634
def __init__(self, name, n_units, batch_size=1, **kwargs):
47-
super().__init__(name, **kwargs)
35+
#super().__init__(name, **kwargs)
36+
super(JaxComponent, self).__init__(name, **kwargs)
4837

4938
## Layer Size Setup
5039
self.batch_size = batch_size
@@ -56,30 +45,24 @@ def __init__(self, name, n_units, batch_size=1, **kwargs):
5645
self.outputs = Compartment(restVals, display_name="Spikes") # output compartment
5746
self.tols = Compartment(restVals, display_name="Time-of-Last-Spike", units="ms") # time of last spike
5847

48+
@transition(output_compartments=["outputs", "tols", "key"])
5949
@staticmethod
60-
def _advance_state(t, key, inputs, tols):
50+
def advance_state(t, key, inputs, tols):
51+
## NOTE: should `inputs` be checked if bounded to [0,1]?
6152
key, *subkeys = random.split(key, 2)
6253
outputs = random.bernoulli(subkeys[0], p=inputs).astype(jnp.float32)
63-
tols = _update_times(t, outputs, tols)
54+
# Updates time-of-last-spike (tols) variable:
55+
# output = s = binary spike vector
56+
# tols = current time-of-last-spike variable
57+
tols = (1. - outputs) * tols + (outputs * t)
6458
return outputs, tols, key
6559

66-
@resolver(_advance_state)
67-
def advance_state(self, outputs, tols, key):
68-
self.outputs.set(outputs)
69-
self.tols.set(tols)
70-
self.key.set(key)
71-
60+
@transition(output_compartments=["inputs", "outputs", "tols"])
7261
@staticmethod
73-
def _reset(batch_size, n_units):
62+
def reset(batch_size, n_units):
7463
restVals = jnp.zeros((batch_size, n_units))
7564
return restVals, restVals, restVals
7665

77-
@resolver(_reset)
78-
def reset(self, inputs, outputs, tols):
79-
self.inputs.set(inputs)
80-
self.outputs.set(outputs) #None
81-
self.tols.set(tols)
82-
8366
def save(self, directory, **kwargs):
8467
file_name = directory + "/" + self.name + ".npz"
8568
jnp.savez(file_name, key=self.key.value)

ngclearn/components/jaxComponent.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
import time
22
from jax import random
3-
from ngclearn import resolver, Component, Compartment
3+
#from ngclearn import resolver, Component, Compartment
4+
from ngcsimlib.component import Component
5+
from ngcsimlib.compartment import Compartment
46

57
class JaxComponent(Component):
68
"""
Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
from jax import numpy as jnp, random, jit
2+
from ngcsimlib.context import Context
3+
import numpy as np
4+
np.random.seed(42)
5+
from ngclearn.components import BernoulliCell
6+
from ngcsimlib.compilers import compile_command, wrap_command
7+
from numpy.testing import assert_array_equal
8+
9+
from ngcsimlib.compilers.process import Process, transition
10+
from ngcsimlib.component import Component
11+
from ngcsimlib.compartment import Compartment
12+
from ngcsimlib.context import Context
13+
from ngcsimlib.utils.compartment import Get_Compartment_Batch
14+
15+
16+
def test_bernoulliCell1():
17+
## create seeding keys
18+
dkey = random.PRNGKey(1234)
19+
dkey, *subkeys = random.split(dkey, 6)
20+
# in_dim = 9 # ... dimension of patch data ...
21+
# hid_dim = 9 # ... number of atoms in the dictionary matrix
22+
dt = 1. # ms
23+
T = 300 # ms # (OR) number of E-steps to take during inference
24+
# ---- build a sparse coding linear generative model with a Cauchy prior ----
25+
with Context("Circuit") as ctx:
26+
a = BernoulliCell(name="a", n_units=1, key=subkeys[0])
27+
28+
myProcess = (Process()
29+
>> a.advance_state)
30+
31+
ctx.wrap_and_add_command(myProcess.pure, name="run")
32+
33+
## create and compile core simulation commands
34+
reset_cmd, reset_args = ctx.compile_by_key(
35+
a, compile_key="reset"
36+
)
37+
ctx.add_command(wrap_command(jit(ctx.reset)), name="reset")
38+
advance_cmd, advance_args = ctx.compile_by_key(
39+
a,compile_key="advance_state"
40+
)
41+
ctx.add_command(wrap_command(jit(ctx.advance_state)), name="advance")
42+
43+
44+
## set up non-compiled utility commands
45+
@Context.dynamicCommand
46+
def clamp(x):
47+
a.inputs.set(x)
48+
49+
## input spike train
50+
x_seq = jnp.asarray([[1., 1., 0., 0., 1.]], dtype=jnp.float32)
51+
52+
outs = []
53+
ctx.reset()
54+
for ts in range(x_seq.shape[1]):
55+
x_t = jnp.array([[x_seq[0,ts]]]) ## get data at time t
56+
ctx.clamp(x_t)
57+
ctx.advance(t=ts*1., dt=1.)
58+
outs.append(a.outputs.value)
59+
outs = jnp.concatenate(outs, axis=1)
60+
61+
## output should equal input
62+
assert_array_equal(outs, x_seq)
63+
#print(outs)
64+
65+
test_bernoulliCell1()

tests/components/neurons/graded/test_rateCell.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,4 +48,4 @@ def clamp(x):
4848
circuit.advance(t=ts*1., dt=1.)
4949

5050
print(a.zF.value)
51-
# assertion here if needed!
51+
# assertion here if needed!

0 commit comments

Comments
 (0)