Skip to content

Commit 0ed77fa

Browse files
author
Alexander Ororbia
committed
revised fh-cell w/ unit test
1 parent 06d4a53 commit 0ed77fa

File tree

3 files changed

+102
-59
lines changed

3 files changed

+102
-59
lines changed

ngclearn/components/neurons/spiking/fitzhughNagumoCell.py

Lines changed: 27 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -1,27 +1,16 @@
1-
from jax import numpy as jnp, jit
2-
from ngclearn import resolver, Component, Compartment
31
from ngclearn.components.jaxComponent import JaxComponent
2+
from jax import numpy as jnp, random, jit, nn
3+
from functools import partial
44
from ngclearn.utils import tensorstats
5+
from ngcsimlib.deprecators import deprecate_args
6+
from ngcsimlib.logger import info, warn
57
from ngclearn.utils.diffeq.ode_utils import get_integrator_code, \
68
step_euler, step_rk2
79

8-
@jit
9-
def _update_times(t, s, tols):
10-
"""
11-
Updates time-of-last-spike (tols) variable.
12-
13-
Args:
14-
t: current time (a scalar/int value)
15-
16-
s: binary spike vector
17-
18-
tols: current time-of-last-spike variable
10+
from ngcsimlib.compilers.process import transition
11+
#from ngcsimlib.component import Component
12+
from ngcsimlib.compartment import Compartment
1913

20-
Returns:
21-
updated tols variable
22-
"""
23-
_tols = (1. - s) * tols + (s * t)
24-
return _tols
2514

2615
@jit
2716
def _dfv_internal(j, v, w, a, b, g, tau_m): ## raw voltage dynamics
@@ -45,25 +34,6 @@ def _dfw(t, w, params): ## recovery dynamics wrapper
4534
dv_dt = _dfw_internal(j, v, w, a, b, g, tau_m)
4635
return dv_dt
4736

48-
@jit
49-
def _emit_spike(v, v_thr):
50-
s = (v > v_thr).astype(jnp.float32)
51-
return s
52-
53-
def _run_cell(dt, j, v, w, v_thr, tau_m, tau_w, a, b, g=3., integType=0):
54-
if integType == 1:
55-
v_params = (j, w, a, b, g, tau_m)
56-
_, _v = step_rk2(0., v, _dfv, dt, v_params) #_v = step_rk2(v, v_params, _dfv, dt)
57-
w_params = (j, v, a, b, g, tau_w)
58-
_, _w = step_rk2(0., w, _dfw, dt, w_params) #_w = step_rk2(w, w_params, _dfw, dt)
59-
else: # integType == 0 (default -- Euler)
60-
v_params = (j, w, a, b, g, tau_m)
61-
_, _v = step_euler(0., v, _dfv, dt, v_params) #_v = step_euler(v, v_params, _dfv, dt)
62-
w_params = (j, v, a, b, g, tau_w)
63-
_, _w = step_euler(0., w, _dfw, dt, w_params) #_w = step_euler(w, w_params, _dfw, dt)
64-
s = _emit_spike(_v, v_thr)
65-
return _v, _w, s
66-
6737
class FitzhughNagumoCell(JaxComponent):
6838
"""
6939
The Fitzhugh-Nagumo neuronal cell model; a two-variable simplification
@@ -168,27 +138,34 @@ def __init__(self, name, n_units, tau_m=1., resist_m=1., tau_w=12.5, alpha=0.7,
168138
self.s = Compartment(restVals)
169139
self.tols = Compartment(restVals) ## time-of-last-spike
170140

141+
@transition(output_compartments=["j", "v", "w", "s", "tols"])
171142
@staticmethod
172-
def _advance_state(t, dt, tau_m, R_m, tau_w, v_thr, spike_reset, v0, w0, alpha,
143+
def advance_state(t, dt, tau_m, R_m, tau_w, v_thr, spike_reset, v0, w0, alpha,
173144
beta, gamma, intgFlag, j, v, w, tols):
174-
v, w, s = _run_cell(dt, j * R_m, v, w, v_thr, tau_m, tau_w, alpha, beta,
175-
gamma, intgFlag)
145+
j_mod = j * R_m
146+
if intgFlag == 1:
147+
v_params = (j_mod, w, alpha, beta, gamma, tau_m)
148+
_, _v = step_rk2(0., v, _dfv, dt, v_params) # _v = step_rk2(v, v_params, _dfv, dt)
149+
w_params = (j_mod, v, alpha, beta, gamma, tau_w)
150+
_, _w = step_rk2(0., w, _dfw, dt, w_params) # _w = step_rk2(w, w_params, _dfw, dt)
151+
else: # integType == 0 (default -- Euler)
152+
v_params = (j_mod, w, alpha, beta, gamma, tau_m)
153+
_, _v = step_euler(0., v, _dfv, dt, v_params) # _v = step_euler(v, v_params, _dfv, dt)
154+
w_params = (j_mod, v, alpha, beta, gamma, tau_w)
155+
_, _w = step_euler(0., w, _dfw, dt, w_params) # _w = step_euler(w, w_params, _dfw, dt)
156+
s = (_v > v_thr) * 1.
157+
v = _v
158+
w = _w
159+
176160
if spike_reset: ## if spike-reset used, variables snapped back to initial conditions
177161
v = v * (1. - s) + s * v0
178162
w = w * (1. - s) + s * w0
179-
tols = _update_times(t, s, tols)
163+
tols = (1. - s) * tols + (s * t) ## update tols
180164
return j, v, w, s, tols
181165

182-
@resolver(_advance_state)
183-
def advance_state(self, j, v, w, s, tols):
184-
self.j.set(j)
185-
self.w.set(w)
186-
self.v.set(v)
187-
self.s.set(s)
188-
self.tols.set(tols)
189-
166+
@transition(output_compartments=["j", "v", "w", "s", "tols"])
190167
@staticmethod
191-
def _reset(batch_size, n_units, v0, w0):
168+
def reset(batch_size, n_units, v0, w0):
192169
restVals = jnp.zeros((batch_size, n_units))
193170
j = restVals # None
194171
v = restVals + v0
@@ -197,14 +174,6 @@ def _reset(batch_size, n_units, v0, w0):
197174
tols = restVals #+ 0
198175
return j, v, w, s, tols
199176

200-
@resolver(_reset)
201-
def reset(self, j, v, w, s, tols):
202-
self.j.set(j)
203-
self.v.set(v)
204-
self.w.set(w)
205-
self.s.set(s)
206-
self.tols.set(tols)
207-
208177
@classmethod
209178
def help(cls): ## component help function
210179
properties = {

tests/components/neurons/spiking/test_WTASCell.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ def test_WTASCell1():
2323
# ---- build a simple Poisson cell system ----
2424
with Context(name) as ctx:
2525
a = WTASCell(
26-
name="a", n_units=1, tau_m=25., resist_m=1., key=subkeys[0]
26+
name="a", n_units=2, tau_m=25., resist_m=1., key=subkeys[0]
2727
)
2828

2929
#"""
Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
from jax import numpy as jnp, random, jit
2+
from ngcsimlib.context import Context
3+
import numpy as np
4+
5+
np.random.seed(42)
6+
from ngclearn.components import FitzhughNagumoCell
7+
from ngcsimlib.compilers import compile_command, wrap_command
8+
from numpy.testing import assert_array_equal
9+
10+
from ngcsimlib.compilers.process import Process, transition
11+
from ngcsimlib.component import Component
12+
from ngcsimlib.compartment import Compartment
13+
from ngcsimlib.context import Context
14+
from ngcsimlib.utils.compartment import Get_Compartment_Batch
15+
16+
17+
def test_fitzhughNagumoCellCell1():
18+
name = "fh_ctx"
19+
## create seeding keys
20+
dkey = random.PRNGKey(1234)
21+
dkey, *subkeys = random.split(dkey, 6)
22+
dt = 0.1 #1. # ms
23+
# ---- build a simple Poisson cell system ----
24+
with Context(name) as ctx:
25+
a = FitzhughNagumoCell(
26+
name="a", n_units=1, tau_m=1., resist_m=5., v_thr=2.1, key=subkeys[0]
27+
)
28+
29+
#"""
30+
advance_process = (Process()
31+
>> a.advance_state)
32+
# ctx.wrap_and_add_command(advance_process.pure, name="run")
33+
ctx.wrap_and_add_command(jit(advance_process.pure), name="run")
34+
35+
reset_process = (Process()
36+
>> a.reset)
37+
ctx.wrap_and_add_command(jit(reset_process.pure), name="reset")
38+
#"""
39+
40+
"""
41+
reset_cmd, reset_args = ctx.compile_by_key(a, compile_key="reset")
42+
ctx.add_command(wrap_command(jit(ctx.reset)), name="reset")
43+
advance_cmd, advance_args = ctx.compile_by_key(a, compile_key="advance_state")
44+
ctx.add_command(wrap_command(jit(ctx.advance_state)), name="run")
45+
"""
46+
47+
## set up non-compiled utility commands
48+
@Context.dynamicCommand
49+
def clamp(x):
50+
a.j.set(x)
51+
52+
## input spike train
53+
x_seq = jnp.asarray([[0., 0., 1., 1., 1., 1., 0., 0., 0., 0.]], dtype=jnp.float32)
54+
## desired output/epsp pulses
55+
y_seq = jnp.asarray([[0., 0., 0., 0., 0., 1., 0., 0., 0., 0.]], dtype=jnp.float32)
56+
57+
outs = []
58+
volts = []
59+
recovs = []
60+
ctx.reset()
61+
for ts in range(x_seq.shape[1]):
62+
x_t = x_seq[:, ts:ts+1] ## get data at time t
63+
ctx.clamp(x_t)
64+
ctx.run(t=ts * 1., dt=dt)
65+
outs.append(a.s.value)
66+
volts.append(a.v.value)
67+
recovs.append(a.w.value)
68+
outs = jnp.concatenate(outs, axis=1)
69+
#print(outs)
70+
71+
## output should equal input
72+
assert_array_equal(outs, y_seq)
73+
74+
test_fitzhughNagumoCellCell1()

0 commit comments

Comments
 (0)