Skip to content

Commit d4dfe38

Browse files
author
Alexander Ororbia
committed
revised lif-cell w/ unit-test
1 parent b077ee0 commit d4dfe38

File tree

2 files changed

+125
-79
lines changed

2 files changed

+125
-79
lines changed

ngclearn/components/neurons/spiking/LIFCell.py

Lines changed: 57 additions & 79 deletions
Original file line numberDiff line numberDiff line change
@@ -1,32 +1,31 @@
1+
"""
12
from jax import numpy as jnp, random, jit, nn
23
from functools import partial
34
from ngclearn.utils import tensorstats
45
from ngcsimlib.deprecators import deprecate_args
56
from ngclearn import resolver, Component, Compartment
67
from ngclearn.components.jaxComponent import JaxComponent
8+
from ngclearn.utils.diffeq.ode_utils import get_integrator_code, \
9+
step_euler, step_rk2
10+
from ngclearn.utils.surrogate_fx import (secant_lif_estimator, arctan_estimator,
11+
triangular_estimator,
12+
straight_through_estimator)
13+
"""
14+
from ngclearn.components.jaxComponent import JaxComponent
15+
from jax import numpy as jnp, random, jit, nn
16+
from functools import partial
17+
from ngclearn.utils import tensorstats
18+
from ngcsimlib.deprecators import deprecate_args
19+
from ngcsimlib.logger import info, warn
720
from ngclearn.utils.diffeq.ode_utils import get_integrator_code, \
821
step_euler, step_rk2
922
from ngclearn.utils.surrogate_fx import (secant_lif_estimator, arctan_estimator,
1023
triangular_estimator,
1124
straight_through_estimator)
1225

13-
@jit
14-
def _update_times(t, s, tols):
15-
"""
16-
Updates time-of-last-spike (tols) variable.
17-
18-
Args:
19-
t: current time (a scalar/int value)
20-
21-
s: binary spike vector
22-
23-
tols: current time-of-last-spike variable
24-
25-
Returns:
26-
updated tols variable
27-
"""
28-
_tols = (1. - s) * tols + (s * t)
29-
return _tols
26+
from ngcsimlib.compilers.process import transition
27+
#from ngcsimlib.component import Component
28+
from ngcsimlib.compartment import Compartment
3029

3130
@jit
3231
def _dfv_internal(j, v, rfr, tau_m, refract_T, v_rest, v_decay=1.): ## raw voltage dynamics
@@ -41,37 +40,6 @@ def _dfv(t, v, params): ## voltage dynamics wrapper
4140
dv_dt = _dfv_internal(j, v, rfr, tau_m, refract_T, v_rest, v_decay)
4241
return dv_dt
4342

44-
#@partial(jit, static_argnums=[7, 8, 9, 10, 11, 12])
45-
def _run_cell(dt, j, v, v_thr, v_theta, rfr, skey, tau_m, v_rest, v_reset,
46-
v_decay, refract_T, integType=0):
47-
### Runs leaky integrator (or leaky integrate-and-fire; LIF) neuronal dynamics.
48-
_v_thr = v_theta + v_thr ## calc present voltage threshold
49-
#mask = (rfr >= refract_T).astype(jnp.float32) # get refractory mask
50-
## update voltage / membrane potential
51-
v_params = (j, rfr, tau_m, refract_T, v_rest, v_decay)
52-
if integType == 1:
53-
_, _v = step_rk2(0., v, _dfv, dt, v_params)
54-
else: #_v = v + (v_rest - v) * (dt/tau_m) + (j * mask)
55-
_, _v = step_euler(0., v, _dfv, dt, v_params)
56-
## obtain action potentials/spikes
57-
s = (_v > _v_thr).astype(jnp.float32)
58-
## update refractory variables
59-
_rfr = (rfr + dt) * (1. - s)
60-
## perform hyper-polarization of neuronal cells
61-
_v = _v * (1. - s) + s * v_reset
62-
63-
raw_s = s + 0 ## preserve un-altered spikes
64-
############################################################################
65-
## this is a spike post-processing step
66-
if skey is not None:
67-
m_switch = (jnp.sum(s) > 0.).astype(jnp.float32) ## TODO: not batch-able
68-
rS = s * random.uniform(skey, s.shape)
69-
rS = nn.one_hot(jnp.argmax(rS, axis=1), num_classes=s.shape[1],
70-
dtype=jnp.float32)
71-
s = s * (1. - m_switch) + rS * m_switch
72-
############################################################################
73-
return _v, s, raw_s, _rfr
74-
7543
#@partial(jit, static_argnums=[3, 4])
7644
def _update_theta(dt, v_theta, s, tau_theta, theta_plus=0.05):
7745
### Runs homeostatic threshold update dynamics one step (via Euler integration).
@@ -159,7 +127,7 @@ class LIFCell(JaxComponent): ## leaky integrate-and-fire cell
159127
160128
lower_clamp_voltage: if True, this will ensure voltage never is below
161129
the value of `v_rest` (default: True)
162-
"""
130+
""" ## batch_size arg?
163131

164132
@deprecate_args(thr_jitter=None)
165133
def __init__(self, name, n_units, tau_m, resist_m=1., thr=-52., v_rest=-65.,
@@ -220,41 +188,61 @@ def __init__(self, name, n_units, tau_m, resist_m=1., thr=-52., v_rest=-65.,
220188
units="ms") ## time-of-last-spike
221189
self.surrogate = Compartment(restVals + 1., display_name="Surrogate State Value")
222190

191+
@transition(output_compartments=["v", "s", "s_raw", "rfr", "thr_theta", "tols", "key", "surrogate"])
223192
@staticmethod
224-
def _advance_state(t, dt, tau_m, resist_m, v_rest, v_reset, v_decay, refract_T,
225-
thr, tau_theta, theta_plus, one_spike, lower_clamp_voltage,
226-
intgFlag, d_spike_fx, key, j, v, rfr, thr_theta, tols):
193+
def advance_state(
194+
t, dt, tau_m, resist_m, v_rest, v_reset, v_decay, refract_T, thr, tau_theta, theta_plus,
195+
one_spike, lower_clamp_voltage, intgFlag, d_spike_fx, key, j, v, rfr, thr_theta, tols
196+
):
227197
skey = None ## this is an empty dkey if single_spike mode turned off
228198
if one_spike:
229199
key, skey = random.split(key, 2)
230200
## run one integration step for neuronal dynamics
231201
j = j * resist_m
232-
v, s, raw_spikes, rfr = _run_cell(dt, j, v, thr, thr_theta, rfr, skey,
233-
tau_m, v_rest, v_reset, v_decay,
234-
refract_T, intgFlag)
235-
surrogate = d_spike_fx(v, thr + thr_theta)
202+
############################################################################
203+
### Runs leaky integrator (leaky integrate-and-fire; LIF) neuronal dynamics.
204+
_v_thr = thr_theta + thr #v_theta + v_thr ## calc present voltage threshold
205+
#mask = (rfr >= refract_T).astype(jnp.float32) # get refractory mask
206+
## update voltage / membrane potential
207+
v_params = (j, rfr, tau_m, refract_T, v_rest, v_decay)
208+
if intgFlag == 1:
209+
_, _v = step_rk2(0., v, _dfv, dt, v_params)
210+
else: #_v = v + (v_rest - v) * (dt/tau_m) + (j * mask)
211+
_, _v = step_euler(0., v, _dfv, dt, v_params)
212+
## obtain action potentials/spikes
213+
s = (_v > _v_thr).astype(jnp.float32)
214+
## update refractory variables
215+
_rfr = (rfr + dt) * (1. - s)
216+
## perform hyper-polarization of neuronal cells
217+
_v = _v * (1. - s) + s * v_reset
218+
219+
raw_s = s + 0 ## preserve un-altered spikes
220+
############################################################################
221+
## this is a spike post-processing step
222+
if skey is not None:
223+
m_switch = (jnp.sum(s) > 0.).astype(jnp.float32) ## TODO: not batch-able
224+
rS = s * random.uniform(skey, s.shape)
225+
rS = nn.one_hot(jnp.argmax(rS, axis=1), num_classes=s.shape[1],
226+
dtype=jnp.float32)
227+
s = s * (1. - m_switch) + rS * m_switch
228+
############################################################################
229+
raw_spikes = raw_s
230+
v = _v
231+
rfr = _rfr
232+
233+
surrogate = d_spike_fx(v, _v_thr) #d_spike_fx(v, thr + thr_theta)
236234
if tau_theta > 0.:
237235
## run one integration step for threshold dynamics
238236
thr_theta = _update_theta(dt, thr_theta, raw_spikes, tau_theta, theta_plus)
239237
## update tols
240-
tols = _update_times(t, s, tols)
238+
tols = (1. - s) * tols + (s * t)
241239
if lower_clamp_voltage: ## ensure voltage never < v_rest
242240
v = jnp.maximum(v, v_rest)
243241
return v, s, raw_spikes, rfr, thr_theta, tols, key, surrogate
244242

245-
@resolver(_advance_state)
246-
def advance_state(self, v, s, s_raw, rfr, thr_theta, tols, key, surrogate):
247-
self.v.set(v)
248-
self.s.set(s)
249-
self.s_raw.set(s_raw)
250-
self.rfr.set(rfr)
251-
self.thr_theta.set(thr_theta)
252-
self.tols.set(tols)
253-
self.key.set(key)
254-
self.surrogate.set(surrogate)
255-
243+
@transition(output_compartments=["j", "v", "s", "s_raw", "rfr", "tols", "surrogate"])
256244
@staticmethod
257-
def _reset(batch_size, n_units, v_rest, refract_T):
245+
def reset(batch_size, n_units, v_rest, refract_T):
258246
restVals = jnp.zeros((batch_size, n_units))
259247
j = restVals #+ 0
260248
v = restVals + v_rest
@@ -266,16 +254,6 @@ def _reset(batch_size, n_units, v_rest, refract_T):
266254
surrogate = restVals + 1.
267255
return j, v, s, s_raw, rfr, tols, surrogate
268256

269-
@resolver(_reset)
270-
def reset(self, j, v, s, s_raw, rfr, tols, surrogate):
271-
self.j.set(j)
272-
self.v.set(v)
273-
self.s.set(s)
274-
self.s_raw.set(s_raw)
275-
self.rfr.set(rfr)
276-
self.tols.set(tols)
277-
self.surrogate.set(surrogate)
278-
279257
def save(self, directory, **kwargs):
280258
## do a protected save of constants, depending on whether they are floats or arrays
281259
tau_m = (self.tau_m if isinstance(self.tau_m, float)
Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
from jax import numpy as jnp, random, jit
2+
from ngcsimlib.context import Context
3+
import numpy as np
4+
np.random.seed(42)
5+
from ngclearn.components import LIFCell
6+
from ngcsimlib.compilers import compile_command, wrap_command
7+
from numpy.testing import assert_array_equal
8+
9+
from ngcsimlib.compilers.process import Process, transition
10+
from ngcsimlib.component import Component
11+
from ngcsimlib.compartment import Compartment
12+
from ngcsimlib.context import Context
13+
from ngcsimlib.utils.compartment import Get_Compartment_Batch
14+
15+
def test_LIFCell1():
16+
## create seeding keys
17+
dkey = random.PRNGKey(1234)
18+
dkey, *subkeys = random.split(dkey, 6)
19+
dt = 1. # ms
20+
trace_increment = 0.1
21+
# ---- build a simple Poisson cell system ----
22+
with Context("Circuit") as ctx:
23+
a = LIFCell(
24+
name="a", n_units=1, tau_m=5., resist_m=30., key=subkeys[0]
25+
)
26+
27+
#"""
28+
advance_process = (Process()
29+
>> a.advance_state)
30+
#ctx.wrap_and_add_command(advance_process.pure, name="run")
31+
ctx.wrap_and_add_command(jit(advance_process.pure), name="run")
32+
33+
reset_process = (Process()
34+
>> a.reset)
35+
ctx.wrap_and_add_command(jit(reset_process.pure), name="reset")
36+
#"""
37+
38+
"""
39+
reset_cmd, reset_args = ctx.compile_by_key(a, compile_key="reset")
40+
ctx.add_command(wrap_command(jit(ctx.reset)), name="reset")
41+
advance_cmd, advance_args = ctx.compile_by_key(a, compile_key="advance_state")
42+
ctx.add_command(wrap_command(jit(ctx.advance_state)), name="run")
43+
"""
44+
45+
## set up non-compiled utility commands
46+
@Context.dynamicCommand
47+
def clamp(x):
48+
a.j.set(x)
49+
50+
## input spike train
51+
x_seq = jnp.asarray([[1., 1., 1., 1., 1., 0., 0., 1., 1., 1., 1., 1., 1., 0.]], dtype=jnp.float32)
52+
## desired output/epsp pulses
53+
y_seq = jnp.asarray([[0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0.]], dtype=jnp.float32)
54+
55+
outs = []
56+
ctx.reset()
57+
for ts in range(x_seq.shape[1]):
58+
x_t = jnp.array([[x_seq[0, ts]]]) ## get data at time t
59+
ctx.clamp(x_t)
60+
ctx.run(t=ts * 1., dt=dt)
61+
outs.append(a.s.value)
62+
outs = jnp.concatenate(outs, axis=1)
63+
#print(outs)
64+
65+
## output should equal input
66+
assert_array_equal(outs, y_seq)
67+
68+
test_LIFCell1()

0 commit comments

Comments
 (0)