Skip to content

Commit 06d4a53

Browse files
author
Alexander Ororbia
committed
revised wtas-cell w/ unit test
1 parent b04822a commit 06d4a53

File tree

2 files changed

+94
-76
lines changed

2 files changed

+94
-76
lines changed

ngclearn/components/neurons/spiking/WTASCell.py

Lines changed: 24 additions & 76 deletions
Original file line numberDiff line numberDiff line change
@@ -1,64 +1,16 @@
11
from jax import numpy as jnp, random, jit, nn
2-
from ngclearn import resolver, Component, Compartment
32
from ngclearn.components.jaxComponent import JaxComponent
3+
from jax import numpy as jnp, random, jit, nn
4+
from functools import partial
45
from ngclearn.utils import tensorstats
5-
from ngclearn.utils.model_utils import softmax
6-
7-
@jit
8-
def _update_times(t, s, tols):
9-
"""
10-
Updates time-of-last-spike (tols) variable.
11-
12-
Args:
13-
t: current time (a scalar/int value)
14-
15-
s: binary spike vector
16-
17-
tols: current time-of-last-spike variable
18-
19-
Returns:
20-
updated tols variable
21-
"""
22-
_tols = (1. - s) * tols + (s * t)
23-
return _tols
24-
@jit
25-
def _run_cell(dt, j, v, rfr, v_thr, tau_m, R_m, thr_gain=0.002, refract_T=0.):
26-
"""
27-
Runs leaky integrator neuronal dynamics
28-
29-
Args:
30-
dt: integration time constant (milliseconds, or ms)
31-
32-
j: electrical current value
33-
34-
v: membrane potential (voltage, in milliVolts or mV) value (at t)
35-
36-
rfr: refractory variable vector (one per neuronal cell)
37-
38-
v_thr: base voltage threshold value (in mV)
39-
40-
tau_m: cell membrane time constant
6+
from ngcsimlib.deprecators import deprecate_args
7+
from ngcsimlib.logger import info, warn
418

42-
R_m: cell membrane resistance
43-
44-
thr_gain: increment to be applied to threshold upon spike occurrence
45-
46-
refract_T: (relative) refractory time period (in ms; Default
47-
value is 1 ms)
9+
from ngcsimlib.compilers.process import transition
10+
#from ngcsimlib.component import Component
11+
from ngcsimlib.compartment import Compartment
12+
from ngclearn.utils.model_utils import softmax
4813

49-
Returns:
50-
voltage(t+dt), spikes, updated voltage thresholds, updated refactory variables
51-
"""
52-
mask = (rfr >= refract_T).astype(jnp.float32) ## check refractory period
53-
v = (j * R_m) * mask
54-
vp = softmax(v) # convert to Categorical (spike) probabilities
55-
#s = nn.one_hot(jnp.argmax(vp, axis=1), j.shape[1]) ## hard-max spike
56-
s = (vp > v_thr).astype(jnp.float32) ## calculate action potential
57-
q = 1. ## Note: thr_gain ==> "rho_b"
58-
dthr = jnp.sum(s, axis=1, keepdims=True) - q
59-
v_thr = jnp.maximum(v_thr + dthr * thr_gain, 0.025) ## calc new threshold
60-
rfr = (rfr + dt) * (1. - s) + s * dt # set refract to dt
61-
return v, s, v_thr, rfr
6214

6315
class WTASCell(JaxComponent): ## winner-take-all spiking cell
6416
"""
@@ -136,22 +88,26 @@ def __init__(
13688
self.rfr = Compartment(restVals + self.refract_T)
13789
self.tols = Compartment(restVals) ## time-of-last-spike
13890

91+
@transition(output_compartments=["v", "s", "thr", "rfr", "tols"])
13992
@staticmethod
140-
def _advance_state(t, dt, tau_m, R_m, thr_gain, refract_T, j, v, thr, rfr, tols):
141-
v, s, thr, rfr = _run_cell(dt, j, v, rfr, thr, tau_m, R_m, thr_gain, refract_T)
142-
tols = _update_times(t, s, tols) ## update tols
93+
def advance_state(t, dt, tau_m, R_m, thr_gain, refract_T, j, v, thr, rfr, tols):
94+
mask = (rfr >= refract_T) * 1. ## check refractory period
95+
v = (j * R_m) * mask
96+
vp = softmax(v) # convert to Categorical (spike) probabilities
97+
# s = nn.one_hot(jnp.argmax(vp, axis=1), j.shape[1]) ## hard-max spike
98+
s = (vp > thr) * 1. ## calculate action potential
99+
q = 1. ## Note: thr_gain ==> "rho_b"
100+
## increment threshold upon spike(s) occurrence
101+
dthr = jnp.sum(s, axis=1, keepdims=True) - q
102+
thr = jnp.maximum(thr + dthr * thr_gain, 0.025) ## calc new threshold
103+
rfr = (rfr + dt) * (1. - s) + s * dt # set refract to dt
104+
105+
tols = (1. - s) * tols + (s * t) ## update tols
143106
return v, s, thr, rfr, tols
144107

145-
@resolver(_advance_state)
146-
def advance_state(self, v, s, thr, rfr, tols):
147-
self.v.set(v)
148-
self.s.set(s)
149-
self.thr.set(thr)
150-
self.rfr.set(rfr)
151-
self.tols.set(tols)
152-
108+
@transition(output_compartments=["j", "v", "s", "rfr", "tols"])
153109
@staticmethod
154-
def _reset(batch_size, n_units, refract_T):
110+
def reset(batch_size, n_units, refract_T):
155111
restVals = jnp.zeros((batch_size, n_units))
156112
j = restVals #+ 0
157113
v = restVals #+ 0
@@ -160,14 +116,6 @@ def _reset(batch_size, n_units, refract_T):
160116
tols = restVals #+ 0
161117
return j, v, s, rfr, tols
162118

163-
@resolver(_reset)
164-
def reset(self, j, v, s, rfr, tols):
165-
self.j.set(j)
166-
self.v.set(v)
167-
self.s.set(s)
168-
self.rfr.set(rfr)
169-
self.tols.set(tols)
170-
171119
def save(self, directory, **kwargs):
172120
file_name = directory + "/" + self.name + ".npz"
173121
jnp.savez(file_name, threshold=self.thr.value)
Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
from jax import numpy as jnp, random, jit
2+
from ngcsimlib.context import Context
3+
import numpy as np
4+
5+
np.random.seed(42)
6+
from ngclearn.components import WTASCell
7+
from ngcsimlib.compilers import compile_command, wrap_command
8+
from numpy.testing import assert_array_equal
9+
10+
from ngcsimlib.compilers.process import Process, transition
11+
from ngcsimlib.component import Component
12+
from ngcsimlib.compartment import Compartment
13+
from ngcsimlib.context import Context
14+
from ngcsimlib.utils.compartment import Get_Compartment_Batch
15+
16+
17+
def test_WTASCell1():
18+
name = "wtas_ctx"
19+
## create seeding keys
20+
dkey = random.PRNGKey(1234)
21+
dkey, *subkeys = random.split(dkey, 6)
22+
dt = 1. # ms
23+
# ---- build a simple Poisson cell system ----
24+
with Context(name) as ctx:
25+
a = WTASCell(
26+
name="a", n_units=1, tau_m=25., resist_m=1., key=subkeys[0]
27+
)
28+
29+
#"""
30+
advance_process = (Process()
31+
>> a.advance_state)
32+
# ctx.wrap_and_add_command(advance_process.pure, name="run")
33+
ctx.wrap_and_add_command(jit(advance_process.pure), name="run")
34+
35+
reset_process = (Process()
36+
>> a.reset)
37+
ctx.wrap_and_add_command(jit(reset_process.pure), name="reset")
38+
#"""
39+
40+
"""
41+
reset_cmd, reset_args = ctx.compile_by_key(a, compile_key="reset")
42+
ctx.add_command(wrap_command(jit(ctx.reset)), name="reset")
43+
advance_cmd, advance_args = ctx.compile_by_key(a, compile_key="advance_state")
44+
ctx.add_command(wrap_command(jit(ctx.advance_state)), name="run")
45+
"""
46+
47+
## set up non-compiled utility commands
48+
@Context.dynamicCommand
49+
def clamp(x):
50+
a.j.set(x)
51+
52+
## input spike train
53+
x_seq = jnp.asarray([[0., 1.], [0., 1.], [1., 0.], [1., 0.]], dtype=jnp.float32)
54+
## desired output/epsp pulses
55+
y_seq = x_seq
56+
57+
outs = []
58+
ctx.reset()
59+
for ts in range(x_seq.shape[0]):
60+
x_t = x_seq[ts:ts+1, :] ## get data at time t
61+
ctx.clamp(x_t)
62+
ctx.run(t=ts * 1., dt=dt)
63+
outs.append(a.s.value)
64+
outs = jnp.concatenate(outs, axis=0)
65+
#print(outs)
66+
#exit()
67+
## output should equal input
68+
assert_array_equal(outs, y_seq)
69+
70+
#test_WTASCell1()

0 commit comments

Comments
 (0)