Skip to content

Commit 6af9fd4

Browse files
author
Alexander Ororbia
committed
revised slif cell w/ unit-test; needed mod to diffeq
1 parent 0d24653 commit 6af9fd4

File tree

3 files changed

+109
-86
lines changed

3 files changed

+109
-86
lines changed

ngclearn/components/neurons/spiking/sLIFCell.py

Lines changed: 40 additions & 85 deletions
Original file line numberDiff line numberDiff line change
@@ -1,61 +1,15 @@
1+
from ngclearn.components.jaxComponent import JaxComponent
12
from jax import numpy as jnp, random, jit
23
from functools import partial
3-
from ngclearn import resolver, Component, Compartment
4-
from ngclearn.components.jaxComponent import JaxComponent
4+
from ngclearn.utils import tensorstats
5+
from ngcsimlib.deprecators import deprecate_args
6+
from ngcsimlib.logger import info, warn
57
from ngclearn.utils.diffeq.ode_utils import step_euler
68
from ngclearn.utils.surrogate_fx import secant_lif_estimator
7-
from ngclearn.utils import tensorstats
8-
9-
@jit
10-
def _update_times(t, s, tols):
11-
"""
12-
Updates time-of-last-spike (tols) variable.
13-
14-
Args:
15-
t: current time (a scalar/int value)
16-
17-
s: binary spike vector
18-
19-
tols: current time-of-last-spike variable
20-
21-
Returns:
22-
updated tols variable
23-
"""
24-
_tols = (1. - s) * tols + (s * t)
25-
return _tols
26-
27-
@partial(jit, static_argnums=[3,4])
28-
def _modify_current(j, spikes, inh_weights, R_m, inh_R):
29-
"""
30-
A simple function that modifies electrical current j via application of a
31-
scalar membrane resistance value and an approximate form of lateral inhibition.
32-
Note that if no inhibitory resistance is set (i.e., inh_R = 0), then no
33-
lateral inhibition is applied. Functionally, this routine carries out the
34-
following piecewise equation:
35-
36-
| j * R_m - [Wi * s(t-dt)] * inh_R, if inh_R > 0
37-
| j * R_m, otherwise
38-
39-
Args:
40-
j: electrical current value
41-
42-
spikes: previous binary spike vector (for t-dt)
43-
44-
inh_weights: lateral recurrent inhibitory synapses (typically should be
45-
chosen to be a scaled hollow matrix)
46-
47-
R_m: membrane resistance (to multiply/scale j by)
489

49-
inh_R: inhibitory resistance to scale lateral inhibitory current by; if
50-
inh_R = 0, NO lateral inhibitory pressure will be applied
51-
52-
Returns:
53-
modified electrical current value
54-
"""
55-
_j = j * R_m
56-
if inh_R > 0.:
57-
_j = _j - (jnp.matmul(spikes, inh_weights) * inh_R)
58-
return _j
10+
from ngcsimlib.compilers.process import transition
11+
#from ngcsimlib.component import Component
12+
from ngcsimlib.compartment import Compartment
5913

6014
@jit
6115
def _dfv_internal(j, v, rfr, tau_m, refract_T): ## raw voltage dynamics
@@ -97,6 +51,7 @@ def _update_refract_and_spikes(dt, rfr, s, refract_T, sticky_spikes=False):
9751
_s = s * mask + (1. - mask)
9852
return _rfr, _s
9953

54+
@partial(jit, static_argnums=[6, 7, 8, 9, 10, 11])
10055
def _run_cell(dt, j, v, v_thr, tau_m, rfr, spike_fx, refract_T=1., thrGain=0.002,
10156
thrLeak=0.0005, rho_b = 0., sticky_spikes=False, v_min=None):
10257
"""
@@ -206,6 +161,8 @@ class SLIFCell(JaxComponent): ## leaky integrate-and-fire cell
206161
a key setting used by Samadi et al., 2017
207162
208163
thr_jitter: scale of uniform jitter to add to initialization of thresholds
164+
165+
batch_size: batch size dimension of this cell (Default: 1)
209166
"""
210167

211168
# Define Functions
@@ -258,36 +215,44 @@ def __init__(self, name, n_units, tau_m, resist_m, thr, resist_inh=0.,
258215
self.rfr = Compartment(restVals + self.refract_T) ## refractory variable(s)
259216
self.surrogate = Compartment(restVals + 1.) ## surrogate signal
260217

218+
@transition(output_compartments=["j", "s", "tols", "v", "thr", "rfr", "surrogate"])
261219
@staticmethod
262-
def _advance_state(t, dt, inh_weights, R_m, inh_R, d_spike_fx, tau_m,
263-
spike_fx, refract_T, thrGain, thrLeak, rho_b,
264-
sticky_spikes, v_min, j, s, v, thr, rfr, tols):
265-
## run one step of Euler integration over neuronal dynamics
266-
j_curr = j
267-
## apply simplified inhibitory pressure
268-
j_curr = _modify_current(j_curr, s, inh_weights, R_m, inh_R)
269-
j = j_curr # None ## store electrical current
270-
surrogate = d_spike_fx(j_curr, c1=0.82, c2=0.08)
220+
def advance_state(
221+
t, dt, inh_weights, R_m, inh_R, d_spike_fx, tau_m, spike_fx, refract_T, thrGain,
222+
thrLeak, rho_b, sticky_spikes, v_min, j, s, v, thr, rfr, tols
223+
):
224+
#####################################################################################
225+
#The following 3 lines of code modify electrical current j via application of a
226+
#scalar membrane resistance value and an approximate form of lateral inhibition.
227+
#Functionally, this routine carries out the following piecewise equation:
228+
#| j * R_m - [Wi * s(t-dt)] * inh_R, if inh_R > 0
229+
#| j * R_m, otherwise
230+
#| where j: electrical current value, spikes: previous binary spike vector (for t-dt),
231+
# inh_weights: lateral recurrent inhibitory synapses (typically should be chosen
232+
# to be a scaled hollow matrix),
233+
#| R_m: membrane resistance (to multiply/scale j by),
234+
#| inh_R: inhibitory resistance to scale lateral inhibitory current by; if inh_R = 0,
235+
# NO lateral inhibitory pressure will be applied
236+
j = j * R_m
237+
if inh_R > 0.: ## if inh_R > 0, then lateral inhibition is applied
238+
j = j - (jnp.matmul(spikes, inh_weights) * inh_R)
239+
#####################################################################################
240+
241+
surrogate = d_spike_fx(j, c1=0.82, c2=0.08)
242+
#surrogate = d_spike_fx(j_curr, c1=0.82, c2=0.08)
243+
271244
v, s, thr, rfr = \
272-
_run_cell(dt, j_curr, v, thr, tau_m,
245+
_run_cell(dt, j, v, thr, tau_m,
273246
rfr, spike_fx, refract_T, thrGain, thrLeak,
274247
rho_b, sticky_spikes=sticky_spikes, v_min=v_min)
248+
275249
## update tols
276-
tols = _update_times(t, s, tols)
250+
tols = (1. - s) * tols + (s * t)
277251
return j, s, tols, v, thr, rfr, surrogate
278252

279-
@resolver(_advance_state)
280-
def advance_state(self, j, s, tols, v, thr, rfr, surrogate):
281-
self.j.set(j)
282-
self.s.set(s)
283-
self.tols.set(tols)
284-
self.thr.set(thr)
285-
self.rfr.set(rfr)
286-
self.surrogate.set(surrogate)
287-
self.v.set(v)
288-
253+
@transition(output_compartments=["j", "s", "tols", "v", "thr", "rfr", "surrogate"])
289254
@staticmethod
290-
def _reset(refract_T, thr_persist, threshold0, batch_size, n_units, thr):
255+
def reset(refract_T, thr_persist, threshold0, batch_size, n_units, thr):
291256
restVals = jnp.zeros((batch_size, n_units))
292257
voltage = restVals
293258
refract = restVals + refract_T
@@ -299,16 +264,6 @@ def _reset(refract_T, thr_persist, threshold0, batch_size, n_units, thr):
299264
thr = threshold0 + 0
300265
return current, spikes, timeOfLastSpike, voltage, thr, refract, surrogate
301266

302-
@resolver(_reset)
303-
def reset(self, j, s, tols, v, thr, rfr, surrogate):
304-
self.j.set(j)
305-
self.s.set(s)
306-
self.tols.set(tols)
307-
self.thr.set(thr)
308-
self.rfr.set(rfr)
309-
self.surrogate.set(surrogate)
310-
self.v.set(v)
311-
312267
def save(self, directory, **kwargs):
313268
file_name = directory + "/" + self.name + ".npz"
314269
if self.thr_persist == False:

ngclearn/utils/diffeq/ode_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ def _step_forward(t, x, dx_dt, dt, x_scale): ## internal step co-routine
5757
_x = x * x_scale + dx_dt * dt
5858
return _t, _x
5959

60-
60+
@partial(jit, static_argnums=(2, 3, 4, 5, ))
6161
def step_euler(t, x, dfx, dt, params, x_scale=1.):
6262
"""
6363
Iteratively integrates one step forward via the Euler method, i.e., a
Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
from jax import numpy as jnp, random, jit
2+
from ngcsimlib.context import Context
3+
import numpy as np
4+
np.random.seed(42)
5+
from ngclearn.components import SLIFCell
6+
from ngcsimlib.compilers import compile_command, wrap_command
7+
from numpy.testing import assert_array_equal
8+
9+
from ngcsimlib.compilers.process import Process, transition
10+
from ngcsimlib.component import Component
11+
from ngcsimlib.compartment import Compartment
12+
from ngcsimlib.context import Context
13+
from ngcsimlib.utils.compartment import Get_Compartment_Batch
14+
15+
def test_sLIFCell1():
16+
## create seeding keys
17+
dkey = random.PRNGKey(1234)
18+
dkey, *subkeys = random.split(dkey, 6)
19+
dt = 1. # ms
20+
trace_increment = 0.1
21+
# ---- build a simple Poisson cell system ----
22+
with Context("Circuit") as ctx:
23+
a = SLIFCell(
24+
name="a", n_units=1, tau_m=50., resist_m=10., thr=0.3, key=subkeys[0]
25+
)
26+
27+
#"""
28+
advance_process = (Process()
29+
>> a.advance_state)
30+
ctx.wrap_and_add_command(jit(advance_process.pure), name="run")
31+
32+
reset_process = (Process()
33+
>> a.reset)
34+
ctx.wrap_and_add_command(jit(reset_process.pure), name="reset")
35+
#"""
36+
37+
"""
38+
reset_cmd, reset_args = ctx.compile_by_key(a, compile_key="reset")
39+
ctx.add_command(wrap_command(jit(ctx.reset)), name="reset")
40+
advance_cmd, advance_args = ctx.compile_by_key(a, compile_key="advance_state")
41+
ctx.add_command(wrap_command(jit(ctx.advance_state)), name="run")
42+
"""
43+
44+
## set up non-compiled utility commands
45+
@Context.dynamicCommand
46+
def clamp(x):
47+
a.j.set(x)
48+
49+
## input spike train
50+
x_seq = jnp.asarray([[1., 1., 0., 0., 1., 1., 0.]], dtype=jnp.float32)
51+
## desired output/epsp pulses
52+
y_seq = jnp.asarray([[0., 1., 0., 0., 0., 1., 0.]], dtype=jnp.float32)
53+
54+
outs = []
55+
ctx.reset()
56+
for ts in range(x_seq.shape[1]):
57+
x_t = jnp.array([[x_seq[0, ts]]]) ## get data at time t
58+
ctx.clamp(x_t)
59+
ctx.run(t=ts * 1., dt=dt)
60+
outs.append(a.s.value)
61+
62+
outs = jnp.concatenate(outs, axis=1)
63+
#print(outs)
64+
65+
## output should equal input
66+
assert_array_equal(outs, y_seq)
67+
68+
test_sLIFCell1()

0 commit comments

Comments
 (0)