Skip to content

Commit aab5e3b

Browse files
author
Alexander Ororbia
committed
revised adex-cell w/ unit test, minor cleanup of quad-lif
1 parent db25673 commit aab5e3b

File tree

3 files changed

+103
-71
lines changed

3 files changed

+103
-71
lines changed

ngclearn/components/neurons/spiking/adExCell.py

Lines changed: 32 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -1,28 +1,14 @@
1-
from jax import numpy as jnp, jit
2-
from ngclearn import resolver, Component, Compartment
31
from ngclearn.components.jaxComponent import JaxComponent
2+
from jax import numpy as jnp, random, jit, nn
3+
from functools import partial
44
from ngclearn.utils import tensorstats
55
from ngcsimlib.deprecators import deprecate_args
6+
from ngcsimlib.logger import info, warn
67
from ngclearn.utils.diffeq.ode_utils import get_integrator_code, \
78
step_euler, step_rk2
8-
9-
@jit
10-
def _update_times(t, s, tols):
11-
"""
12-
Updates time-of-last-spike (tols) variable.
13-
14-
Args:
15-
t: current time (a scalar/int value)
16-
17-
s: binary spike vector
18-
19-
tols: current time-of-last-spike variable
20-
21-
Returns:
22-
updated tols variable
23-
"""
24-
_tols = (1. - s) * tols + (s * t)
25-
return _tols
9+
from ngcsimlib.compilers.process import transition
10+
#from ngcsimlib.component import Component
11+
from ngcsimlib.compartment import Compartment
2612

2713
@jit
2814
def _dfv_internal(j, v, w, tau_m, v_rest, sharpV, vT, R_m): ## raw voltage dynamics
@@ -46,30 +32,6 @@ def _dfw(t, w, params): ## recovery dynamics wrapper
4632
dv_dt = _dfw_internal(j, v, w, a, tau_m, v_rest)
4733
return dv_dt
4834

49-
@jit
50-
def _emit_spike(v, v_thr):
51-
s = (v > v_thr).astype(jnp.float32)
52-
return s
53-
54-
#@partial(jit, static_argnums=[10])
55-
def _run_cell(dt, j, v, w, v_thr, tau_m, tau_w, a, b, sharpV, vT,
56-
v_rest, v_reset, R_m, integType=0):
57-
if integType == 1: ## RK-2/midpoint
58-
v_params = (j, w, tau_m, v_rest, sharpV, vT, R_m)
59-
_, _v = step_rk2(0., v, _dfv, dt, v_params)
60-
w_params = (j, v, a, tau_w, v_rest)
61-
_, _w = step_rk2(0., w, _dfw, dt, w_params)
62-
else: # integType == 0 (default -- Euler)
63-
v_params = (j, w, tau_m, v_rest, sharpV, vT, R_m)
64-
_, _v = step_euler(0., v, _dfv, dt, v_params)
65-
w_params = (j, v, a, tau_w, v_rest)
66-
_, _w = step_euler(0., w, _dfw, dt, w_params)
67-
s = _emit_spike(_v, v_thr)
68-
## hyperpolarize/reset/snap variables
69-
_v = _v * (1. - s) + s * v_reset
70-
_w = _w * (1. - s) + s * (_w + b)
71-
return _v, _w, s
72-
7335
class AdExCell(JaxComponent):
7436
"""
7537
The AdEx (adaptive exponential leaky integrate-and-fire) neuronal cell
@@ -136,10 +98,10 @@ class AdExCell(JaxComponent):
13698
"""
13799

138100
@deprecate_args(v_thr="thr")
139-
def __init__(self, name, n_units, tau_m=15., resist_m=1., tau_w=400.,
140-
v_sharpness=2., intrinsic_mem_thr=-55., thr=5., v_rest=-72.,
141-
v_reset=-75., a=0.1, b=0.75, v0=-70., w0=0.,
142-
integration_type="euler", batch_size=1, **kwargs):
101+
def __init__(
102+
self, name, n_units, tau_m=15., resist_m=1., tau_w=400., v_sharpness=2., intrinsic_mem_thr=-55., thr=5.,
103+
v_rest=-72., v_reset=-75., a=0.1, b=0.75, v0=-70., w0=0., integration_type="euler", batch_size=1, **kwargs
104+
):
143105
super().__init__(name, **kwargs)
144106

145107
## Integration properties
@@ -174,24 +136,32 @@ def __init__(self, name, n_units, tau_m=15., resist_m=1., tau_w=400.,
174136
self.tols = Compartment(restVals, display_name="Time-of-Last-Spike",
175137
units="ms") ## time-of-last-spike
176138

139+
@transition(output_compartments=["j", "v", "w", "s", "tols"])
177140
@staticmethod
178-
def _advance_state(t, dt, tau_m, R_m, tau_w, thr, a, b, sharpV, vT,
179-
v_rest, v_reset, intgFlag, j, v, w, tols):
180-
v, w, s = _run_cell(dt, j, v, w, thr, tau_m, tau_w, a, b, sharpV, vT,
181-
v_rest, v_reset, R_m, intgFlag)
182-
tols = _update_times(t, s, tols)
141+
def advance_state(
142+
t, dt, tau_m, R_m, tau_w, thr, a, b, sharpV, vT, v_rest, v_reset, intgFlag, j, v, w, tols
143+
):
144+
if intgFlag == 1: ## RK-2/midpoint
145+
v_params = (j, w, tau_m, v_rest, sharpV, vT, R_m)
146+
_, _v = step_rk2(0., v, _dfv, dt, v_params)
147+
w_params = (j, v, a, tau_w, v_rest)
148+
_, _w = step_rk2(0., w, _dfw, dt, w_params)
149+
else: # intgFlag == 0 (default -- Euler)
150+
v_params = (j, w, tau_m, v_rest, sharpV, vT, R_m)
151+
_, _v = step_euler(0., v, _dfv, dt, v_params)
152+
w_params = (j, v, a, tau_w, v_rest)
153+
_, _w = step_euler(0., w, _dfw, dt, w_params)
154+
s = (_v > thr) * 1. ## emit spikes/pulses
155+
## hyperpolarize/reset/snap variables
156+
v = _v * (1. - s) + s * v_reset
157+
w = _w * (1. - s) + s * (_w + b)
158+
159+
tols = (1. - s) * tols + (s * t) ## update time-of-last spike variable(s)
183160
return j, v, w, s, tols
184161

185-
@resolver(_advance_state)
186-
def advance_state(self, j, v, w, s, tols):
187-
self.j.set(j)
188-
self.w.set(w)
189-
self.v.set(v)
190-
self.s.set(s)
191-
self.tols.set(tols)
192-
162+
@transition(output_compartments=["j", "v", "w", "s", "tols"])
193163
@staticmethod
194-
def _reset(batch_size, n_units, v0, w0):
164+
def reset(batch_size, n_units, v0, w0):
195165
restVals = jnp.zeros((batch_size, n_units))
196166
j = restVals # None
197167
v = restVals + v0
@@ -200,14 +170,6 @@ def _reset(batch_size, n_units, v0, w0):
200170
tols = restVals #+ 0
201171
return j, v, w, s, tols
202172

203-
@resolver(_reset)
204-
def reset(self, j, v, w, s, tols):
205-
self.j.set(j)
206-
self.v.set(v)
207-
self.w.set(w)
208-
self.s.set(s)
209-
self.tols.set(tols)
210-
211173
@classmethod
212174
def help(cls): ## component help function
213175
properties = {

ngclearn/components/neurons/spiking/quadLIFCell.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,7 @@ class QuadLIFCell(LIFCell): ## quadratic integrate-and-fire cell
102102
sampled from the non-zero spikes detected
103103
""" ## batch_size arg?
104104

105-
@deprecate_args(thr_jitter=None)
105+
@deprecate_args(thr_jitter=None, critical_v="critical_V")
106106
def __init__(
107107
self, name, n_units, tau_m, resist_m=1., thr=-52., v_rest=-65., v_reset=-60., v_scale=-41.6, critical_v=1.,
108108
tau_theta=1e7, theta_plus=0.05, refract_time=5., one_spike=False, integration_type="euler",
Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
from jax import numpy as jnp, random, jit
2+
from ngcsimlib.context import Context
3+
import numpy as np
4+
5+
np.random.seed(42)
6+
from ngclearn.components import AdExCell
7+
from ngcsimlib.compilers import compile_command, wrap_command
8+
from numpy.testing import assert_array_equal
9+
10+
from ngcsimlib.compilers.process import Process, transition
11+
from ngcsimlib.component import Component
12+
from ngcsimlib.compartment import Compartment
13+
from ngcsimlib.context import Context
14+
from ngcsimlib.utils.compartment import Get_Compartment_Batch
15+
16+
17+
def test_adExCell1():
18+
name = "adex_ctx"
19+
## create seeding keys
20+
dkey = random.PRNGKey(1234)
21+
dkey, *subkeys = random.split(dkey, 6)
22+
dt = 1. # ms
23+
# ---- build a simple Poisson cell system ----
24+
with Context(name) as ctx:
25+
a = AdExCell(
26+
name="a", n_units=1, tau_m=50., resist_m=30., thr=-66., key=subkeys[0]
27+
)
28+
29+
#"""
30+
advance_process = (Process()
31+
>> a.advance_state)
32+
# ctx.wrap_and_add_command(advance_process.pure, name="run")
33+
ctx.wrap_and_add_command(jit(advance_process.pure), name="run")
34+
35+
reset_process = (Process()
36+
>> a.reset)
37+
ctx.wrap_and_add_command(jit(reset_process.pure), name="reset")
38+
#"""
39+
40+
"""
41+
reset_cmd, reset_args = ctx.compile_by_key(a, compile_key="reset")
42+
ctx.add_command(wrap_command(jit(ctx.reset)), name="reset")
43+
advance_cmd, advance_args = ctx.compile_by_key(a, compile_key="advance_state")
44+
ctx.add_command(wrap_command(jit(ctx.advance_state)), name="run")
45+
"""
46+
47+
## set up non-compiled utility commands
48+
@Context.dynamicCommand
49+
def clamp(x):
50+
a.j.set(x)
51+
52+
## input spike train
53+
x_seq = jnp.ones((1, 10))
54+
## desired output/epsp pulses
55+
y_seq = jnp.asarray([[0., 0., 0., 0., 0., 0., 0., 1., 0., 0.]], dtype=jnp.float32)
56+
57+
outs = []
58+
ctx.reset()
59+
for ts in range(x_seq.shape[1]):
60+
x_t = jnp.array([[x_seq[0, ts]]]) ## get data at time t
61+
ctx.clamp(x_t)
62+
ctx.run(t=ts * 1., dt=dt)
63+
outs.append(a.s.value)
64+
outs = jnp.concatenate(outs, axis=1)
65+
#print(outs)
66+
67+
## output should equal input
68+
assert_array_equal(outs, y_seq)
69+
70+
test_adExCell1()

0 commit comments

Comments
 (0)