Skip to content

Commit 9ea587a

Browse files
author
Alexander Ororbia
committed
revised izh-cell w/ unit test
1 parent 0ed77fa commit 9ea587a

File tree

3 files changed

+108
-79
lines changed

3 files changed

+108
-79
lines changed

ngclearn/components/neurons/spiking/izhikevichCell.py

Lines changed: 35 additions & 73 deletions
Original file line numberDiff line numberDiff line change
@@ -1,27 +1,16 @@
1-
from jax import numpy as jnp, jit
2-
from ngclearn.utils import tensorstats
3-
from ngclearn import resolver, Component, Compartment
41
from ngclearn.components.jaxComponent import JaxComponent
2+
from jax import numpy as jnp, random, jit, nn
3+
from functools import partial
4+
from ngclearn.utils import tensorstats
5+
from ngcsimlib.deprecators import deprecate_args
6+
from ngcsimlib.logger import info, warn
57
from ngclearn.utils.diffeq.ode_utils import get_integrator_code, \
68
step_euler, step_rk2
79

8-
@jit
9-
def _update_times(t, s, tols):
10-
"""
11-
Updates time-of-last-spike (tols) variable.
12-
13-
Args:
14-
t: current time (a scalar/int value)
15-
16-
s: binary spike vector
10+
from ngcsimlib.compilers.process import transition
11+
#from ngcsimlib.component import Component
12+
from ngcsimlib.compartment import Compartment
1713

18-
tols: current time-of-last-spike variable
19-
20-
Returns:
21-
updated tols variable
22-
"""
23-
_tols = (1. - s) * tols + (s * t)
24-
return _tols
2514

2615
@jit
2716
def _dfv_internal(j, v, w, b, tau_m): ## raw voltage dynamics
@@ -55,39 +44,6 @@ def _post_process(s, _v, _w, v, w, c, d): ## internal post-processing routine
5544
w_next = _w * (1. - s) + s * (w + d)
5645
return v_next, w_next
5746

58-
@jit
59-
def _emit_spike(v, v_thr):
60-
s = (v > v_thr).astype(jnp.float32)
61-
return s
62-
63-
@jit
64-
def _modify_current(j, R_m):
65-
_j = j * R_m
66-
return _j
67-
68-
def _run_cell(dt, j, v, s, w, v_thr=30., tau_m=1., tau_w=50., b=0.2, c=-65., d=8.,
69-
R_m=1., integType=0):
70-
## note: a = 0.1 --> fast spikes, a = 0.02 --> regular spikes
71-
a = 1./tau_w ## we map time constant to variable "a" (a = 1/tau_w)
72-
_j = _modify_current(j, R_m)
73-
#_j = jnp.maximum(-30.0, _j) ## lower-bound/clip input current
74-
## check for spikes
75-
s = _emit_spike(v, v_thr)
76-
## for non-spikes, evolve according to dynamics
77-
if integType == 1:
78-
v_params = (_j, w, b, tau_m)
79-
_, _v = step_rk2(0., v, _dfv, dt, v_params) #_v = step_rk2(v, v_params, _dfv, dt)
80-
w_params = (_j, v, b, tau_w)
81-
_, _w = step_rk2(0., w, _dfw, dt, w_params) #_w = step_rk2(w, w_params, _dfw, dt)
82-
else: # integType == 0 (default -- Euler)
83-
v_params = (_j, w, b, tau_m)
84-
_, _v = step_euler(0., v, _dfv, dt, v_params) #_v = step_euler(v, v_params, _dfv, dt)
85-
w_params = (_j, v, b, tau_w)
86-
_, _w = step_euler(0., w, _dfw, dt, w_params) #_w = step_euler(w, w_params, _dfw, dt)
87-
## for spikes, snap to particular states
88-
_v, _w = _post_process(s, _v, _w, v, w, c, d)
89-
return _v, _w, s
90-
9147
class IzhikevichCell(JaxComponent): ## Izhikevich neuronal cell
9248
"""
9349
A spiking cell based on Izhikevich's model of neuronal dynamics. Note that
@@ -197,24 +153,38 @@ def __init__(self, name, n_units, tau_m=1., resist_m=1., v_thr=30., v_reset=-65.
197153
self.s = Compartment(restVals)
198154
self.tols = Compartment(restVals) ## time-of-last-spike
199155

156+
@transition(output_compartments=["j", "v", "w", "s", "tols"])
200157
@staticmethod
201-
def _advance_state(t, dt, tau_m, tau_w, v_thr, coupling, v_reset, w_reset, R_m,
158+
def advance_state(t, dt, tau_m, tau_w, v_thr, coupling, v_reset, w_reset, R_m,
202159
intgFlag, j, v, w, s, tols):
203-
v, w, s = _run_cell(dt, j, v, s, w, v_thr=v_thr, tau_m=tau_m, tau_w=tau_w,
204-
b=coupling, c=v_reset, d=w_reset, R_m=R_m, integType=intgFlag)
205-
tols = _update_times(t, s, tols)
160+
## note: a = 0.1 --> fast spikes, a = 0.02 --> regular spikes
161+
a = 1. / tau_w ## we map time constant to variable "a" (a = 1/tau_w)
162+
_j = j * R_m
163+
# _j = jnp.maximum(-30.0, _j) ## lower-bound/clip input current
164+
## check for spikes
165+
s = (v > v_thr) * 1.
166+
## for non-spikes, evolve according to dynamics
167+
if intgFlag == 1:
168+
v_params = (_j, w, coupling, tau_m)
169+
_, _v = step_rk2(0., v, _dfv, dt, v_params) # _v = step_rk2(v, v_params, _dfv, dt)
170+
w_params = (_j, v, coupling, tau_w)
171+
_, _w = step_rk2(0., w, _dfw, dt, w_params) # _w = step_rk2(w, w_params, _dfw, dt)
172+
else: # integType == 0 (default -- Euler)
173+
v_params = (_j, w, coupling, tau_m)
174+
_, _v = step_euler(0., v, _dfv, dt, v_params) # _v = step_euler(v, v_params, _dfv, dt)
175+
w_params = (_j, v, coupling, tau_w)
176+
_, _w = step_euler(0., w, _dfw, dt, w_params) # _w = step_euler(w, w_params, _dfw, dt)
177+
## for spikes, snap to particular states
178+
_v, _w = _post_process(s, _v, _w, v, w, v_reset, w_reset)
179+
v = _v
180+
w = _w
181+
182+
tols = (1. - s) * tols + (s * t) ## update tols
206183
return j, v, w, s, tols
207184

208-
@resolver(_advance_state)
209-
def advance_state(self, j, v, w, s, tols):
210-
self.j.set(j)
211-
self.w.set(w)
212-
self.v.set(v)
213-
self.s.set(s)
214-
self.tols.set(tols)
215-
185+
@transition(output_compartments=["j", "v", "w", "s", "tols"])
216186
@staticmethod
217-
def _reset(batch_size, n_units, v0, w0):
187+
def reset(batch_size, n_units, v0, w0):
218188
restVals = jnp.zeros((batch_size, n_units))
219189
j = restVals # None
220190
v = restVals + v0
@@ -223,14 +193,6 @@ def _reset(batch_size, n_units, v0, w0):
223193
tols = restVals #+ 0
224194
return j, v, w, s, tols
225195

226-
@resolver(_reset)
227-
def reset(self, j, v, w, s, tols):
228-
self.j.set(j)
229-
self.v.set(v)
230-
self.w.set(w)
231-
self.s.set(s)
232-
self.tols.set(tols)
233-
234196
@classmethod
235197
def help(cls): ## component help function
236198
properties = {

tests/components/neurons/spiking/test_fitzhughNagumoCell.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
from ngcsimlib.utils.compartment import Get_Compartment_Batch
1515

1616

17-
def test_fitzhughNagumoCellCell1():
17+
def test_fitzhughNagumoCell1():
1818
name = "fh_ctx"
1919
## create seeding keys
2020
dkey = random.PRNGKey(1234)
@@ -55,20 +55,16 @@ def clamp(x):
5555
y_seq = jnp.asarray([[0., 0., 0., 0., 0., 1., 0., 0., 0., 0.]], dtype=jnp.float32)
5656

5757
outs = []
58-
volts = []
59-
recovs = []
6058
ctx.reset()
6159
for ts in range(x_seq.shape[1]):
6260
x_t = x_seq[:, ts:ts+1] ## get data at time t
6361
ctx.clamp(x_t)
6462
ctx.run(t=ts * 1., dt=dt)
6563
outs.append(a.s.value)
66-
volts.append(a.v.value)
67-
recovs.append(a.w.value)
6864
outs = jnp.concatenate(outs, axis=1)
6965
#print(outs)
7066

7167
## output should equal input
7268
assert_array_equal(outs, y_seq)
7369

74-
test_fitzhughNagumoCellCell1()
70+
#test_fitzhughNagumoCell1()
Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
from jax import numpy as jnp, random, jit
2+
from ngcsimlib.context import Context
3+
import numpy as np
4+
5+
np.random.seed(42)
6+
from ngclearn.components import IzhikevichCell
7+
from ngcsimlib.compilers import compile_command, wrap_command
8+
from numpy.testing import assert_array_equal
9+
10+
from ngcsimlib.compilers.process import Process, transition
11+
from ngcsimlib.component import Component
12+
from ngcsimlib.compartment import Compartment
13+
from ngcsimlib.context import Context
14+
from ngcsimlib.utils.compartment import Get_Compartment_Batch
15+
16+
17+
def test_izhikevichCell1():
18+
name = "izh_ctx"
19+
## create seeding keys
20+
dkey = random.PRNGKey(1234)
21+
dkey, *subkeys = random.split(dkey, 6)
22+
dt = 1. # ms
23+
# ---- build a simple Poisson cell system ----
24+
with Context(name) as ctx:
25+
a = IzhikevichCell(
26+
name="a", n_units=1, tau_m=1., resist_m=4., v_thr=30., key=subkeys[0]
27+
)
28+
29+
#"""
30+
advance_process = (Process()
31+
>> a.advance_state)
32+
# ctx.wrap_and_add_command(advance_process.pure, name="run")
33+
ctx.wrap_and_add_command(jit(advance_process.pure), name="run")
34+
35+
reset_process = (Process()
36+
>> a.reset)
37+
ctx.wrap_and_add_command(jit(reset_process.pure), name="reset")
38+
#"""
39+
40+
"""
41+
reset_cmd, reset_args = ctx.compile_by_key(a, compile_key="reset")
42+
ctx.add_command(wrap_command(jit(ctx.reset)), name="reset")
43+
advance_cmd, advance_args = ctx.compile_by_key(a, compile_key="advance_state")
44+
ctx.add_command(wrap_command(jit(ctx.advance_state)), name="run")
45+
"""
46+
47+
## set up non-compiled utility commands
48+
@Context.dynamicCommand
49+
def clamp(x):
50+
a.j.set(x)
51+
52+
## input spike train
53+
x_seq = jnp.asarray([[0., 0., 1., 1., 1., 1., 1., 1., 1., 1., 0., 0., 0., 0., 0.]], dtype=jnp.float32)
54+
## desired output/epsp pulses
55+
y_seq = jnp.asarray([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0.]], dtype=jnp.float32)
56+
57+
outs = []
58+
ctx.reset()
59+
for ts in range(x_seq.shape[1]):
60+
x_t = x_seq[:, ts:ts+1] ## get data at time t
61+
ctx.clamp(x_t)
62+
ctx.run(t=ts * 1., dt=dt)
63+
outs.append(a.s.value)
64+
print(a.v.value)
65+
outs = jnp.concatenate(outs, axis=1)
66+
print(outs)
67+
#exit()
68+
## output should equal input
69+
assert_array_equal(outs, y_seq)
70+
71+
#test_izhikevichCell1()

0 commit comments

Comments
 (0)