Skip to content

Commit 3c0ed36

Browse files
author
Alexander Ororbia
committed
refactored fn-cell and test passed
1 parent 4013bc0 commit 3c0ed36

File tree

3 files changed

+67
-76
lines changed

3 files changed

+67
-76
lines changed

ngclearn/components/neurons/spiking/adExCell.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ def _dfw(t, w, params): ## recovery dynamics wrapper
3232
dv_dt = _dfw_internal(j, v, w, a, tau_m, v_rest)
3333
return dv_dt
3434

35-
class AdExCell(JaxComponent):
35+
class AdExCell(JaxComponent): ## adaptive exponential integrate-and-fire cell
3636
"""
3737
The AdEx (adaptive exponential leaky integrate-and-fire) neuronal cell
3838
model; a two-variable model. This cell model iteratively evolves

ngclearn/components/neurons/spiking/fitzhughNagumoCell.py

Lines changed: 43 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,12 @@
22
from jax import numpy as jnp, random, jit, nn
33
from functools import partial
44
from ngclearn.utils import tensorstats
5-
from ngcsimlib.deprecators import deprecate_args
5+
from ngcsimlib import deprecate_args
66
from ngcsimlib.logger import info, warn
77
from ngclearn.utils.diffeq.ode_utils import get_integrator_code, \
88
step_euler, step_rk2
99

10-
from ngcsimlib.compilers.process import transition
11-
#from ngcsimlib.component import Component
10+
from ngcsimlib.parser import compilable
1211
from ngcsimlib.compartment import Compartment
1312

1413

@@ -34,7 +33,7 @@ def _dfw(t, w, params): ## recovery dynamics wrapper
3433
dv_dt = _dfw_internal(j, v, w, a, b, g, tau_m)
3534
return dv_dt
3635

37-
class FitzhughNagumoCell(JaxComponent):
36+
class FitzhughNagumoCell(JaxComponent): ## F-H cell
3837
"""
3938
The Fitzhugh-Nagumo neuronal cell model; a two-variable simplification
4039
of the Hodgkin-Huxley (squid axon) model. This cell model iteratively evolves
@@ -103,10 +102,10 @@ class FitzhughNagumoCell(JaxComponent):
103102
at an increase in computational cost (and simulation time)
104103
"""
105104

106-
# Define Functions
107-
def __init__(self, name, n_units, tau_m=1., resist_m=1., tau_w=12.5, alpha=0.7,
108-
beta=0.8, gamma=3., v0=0., w0=0., v_thr=1.07, spike_reset=False,
109-
integration_type="euler", **kwargs):
105+
def __init__(
106+
self, name, n_units, tau_m=1., resist_m=1., tau_w=12.5, alpha=0.7, beta=0.8, gamma=3., v0=0., w0=0.,
107+
v_thr=1.07, spike_reset=False, integration_type="euler", **kwargs
108+
):
110109
super().__init__(name, **kwargs)
111110

112111
## Integration properties
@@ -115,7 +114,7 @@ def __init__(self, name, n_units, tau_m=1., resist_m=1., tau_w=12.5, alpha=0.7,
115114

116115
## Cell properties
117116
self.tau_m = tau_m
118-
self.R_m = resist_m
117+
self.resist_m = resist_m ## resistance R_m
119118
self.tau_w = tau_w
120119
self.alpha = alpha
121120
self.beta = beta
@@ -138,41 +137,44 @@ def __init__(self, name, n_units, tau_m=1., resist_m=1., tau_w=12.5, alpha=0.7,
138137
self.s = Compartment(restVals)
139138
self.tols = Compartment(restVals) ## time-of-last-spike
140139

141-
@transition(output_compartments=["j", "v", "w", "s", "tols"])
142-
@staticmethod
143-
def advance_state(t, dt, tau_m, R_m, tau_w, v_thr, spike_reset, v0, w0, alpha,
144-
beta, gamma, intgFlag, j, v, w, tols):
145-
j_mod = j * R_m
146-
if intgFlag == 1:
147-
v_params = (j_mod, w, alpha, beta, gamma, tau_m)
148-
_, _v = step_rk2(0., v, _dfv, dt, v_params) # _v = step_rk2(v, v_params, _dfv, dt)
149-
w_params = (j_mod, v, alpha, beta, gamma, tau_w)
150-
_, _w = step_rk2(0., w, _dfw, dt, w_params) # _w = step_rk2(w, w_params, _dfw, dt)
140+
@compilable
141+
def advance_state(self, t, dt):
142+
j_mod = self.j.get() * self.resist_m
143+
if self.intgFlag == 1:
144+
v_params = (j_mod, self.w.get(), self.alpha, self.beta, self.gamma, self.tau_m)
145+
_, _v = step_rk2(0., self.v.get(), _dfv, dt, v_params) # _v = step_rk2(v, v_params, _dfv, dt)
146+
w_params = (j_mod, self.v.get(), self.alpha, self.beta, self.gamma, self.tau_w)
147+
_, _w = step_rk2(0., self.w.get(), _dfw, dt, w_params) # _w = step_rk2(w, w_params, _dfw, dt)
151148
else: # integType == 0 (default -- Euler)
152-
v_params = (j_mod, w, alpha, beta, gamma, tau_m)
153-
_, _v = step_euler(0., v, _dfv, dt, v_params) # _v = step_euler(v, v_params, _dfv, dt)
154-
w_params = (j_mod, v, alpha, beta, gamma, tau_w)
155-
_, _w = step_euler(0., w, _dfw, dt, w_params) # _w = step_euler(w, w_params, _dfw, dt)
156-
s = (_v > v_thr) * 1.
149+
v_params = (j_mod, self.w.get(), self.alpha, self.beta, self.gamma, self.tau_m)
150+
_, _v = step_euler(0., self.v.get(), _dfv, dt, v_params) # _v = step_euler(v, v_params, _dfv, dt)
151+
w_params = (j_mod, self.v.get(), self.alpha, self.beta, self.gamma, self.tau_w)
152+
_, _w = step_euler(0., self.w.get(), _dfw, dt, w_params) # _w = step_euler(w, w_params, _dfw, dt)
153+
s = (_v > self.v_thr) * 1.
157154
v = _v
158155
w = _w
159156

160-
if spike_reset: ## if spike-reset used, variables snapped back to initial conditions
161-
v = v * (1. - s) + s * v0
162-
w = w * (1. - s) + s * w0
163-
tols = (1. - s) * tols + (s * t) ## update tols
164-
return j, v, w, s, tols
165-
166-
@transition(output_compartments=["j", "v", "w", "s", "tols"])
167-
@staticmethod
168-
def reset(batch_size, n_units, v0, w0):
169-
restVals = jnp.zeros((batch_size, n_units))
170-
j = restVals # None
171-
v = restVals + v0
172-
w = restVals + w0
173-
s = restVals #+ 0
174-
tols = restVals #+ 0
175-
return j, v, w, s, tols
157+
if self.spike_reset: ## if spike-reset used, variables snapped back to initial conditions
158+
v = v * (1. - s) + s * self.v0
159+
w = w * (1. - s) + s * self.w0
160+
161+
## update time-of-last spike variable(s)
162+
self.tols.set((1. - s) * self.tols.get() + (s * t))
163+
164+
# self.j.set(j) ## j is not getting modified in these dynamics
165+
self.v.set(v)
166+
self.w.set(w)
167+
self.s.set(s)
168+
169+
@compilable
170+
def reset(self):
171+
restVals = jnp.zeros((self.batch_size, self.n_units))
172+
if not self.j.targeted:
173+
self.j.set(restVals)
174+
self.v.set(restVals + self.v0)
175+
self.w.set(restVals + self.w0)
176+
self.s.set(restVals)
177+
self.tols.set(restVals)
176178

177179
@classmethod
178180
def help(cls): ## component help function
@@ -197,8 +199,7 @@ def help(cls): ## component help function
197199
"resist_m": "Membrane resistance value",
198200
"tau_w": "Recovery variable time constant",
199201
"v_thr": "Base voltage threshold value",
200-
"spike_reset": "Should voltage/recover be snapped to initial "
201-
"condition(s) if spike emitted?",
202+
"spike_reset": "Should voltage/recover be snapped to initial condition(s) if spike emitted?",
202203
"alpha": "Dimensionless recovery variable shift factor `a",
203204
"beta": "Dimensionless recovery variable scale factor `b`",
204205
"gamma": "Power-term divisor constant",

tests/components/neurons/spiking/test_fitzhughNagumoCell.py

Lines changed: 23 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,11 @@
11
from jax import numpy as jnp, random, jit
22
from ngcsimlib.context import Context
33
import numpy as np
4-
54
np.random.seed(42)
6-
from ngclearn.components import FitzhughNagumoCell
7-
from ngcsimlib.compilers import compile_command, wrap_command
8-
from numpy.testing import assert_array_equal
95

10-
from ngcsimlib.compilers.process import Process, transition
11-
from ngcsimlib.component import Component
12-
from ngcsimlib.compartment import Compartment
13-
from ngcsimlib.context import Context
14-
from ngcsimlib.utils.compartment import Get_Compartment_Batch
6+
from ngclearn import Context, MethodProcess
7+
from ngclearn.components.neurons.spiking.fitzhughNagumoCell import FitzhughNagumoCell
8+
from numpy.testing import assert_array_equal
159

1610

1711
def test_fitzhughNagumoCell1():
@@ -26,43 +20,39 @@ def test_fitzhughNagumoCell1():
2620
name="a", n_units=1, tau_m=1., resist_m=5., v_thr=2.1, key=subkeys[0]
2721
)
2822

29-
#"""
30-
advance_process = (Process("advance_proc")
23+
# """
24+
advance_process = (MethodProcess("advance_proc")
3125
>> a.advance_state)
32-
# ctx.wrap_and_add_command(advance_process.pure, name="run")
33-
ctx.wrap_and_add_command(jit(advance_process.pure), name="run")
26+
# ctx.wrap_and_add_command(jit(advance_process.pure), name="run")
3427

35-
reset_process = (Process("reset_proc")
28+
reset_process = (MethodProcess("reset_proc")
3629
>> a.reset)
37-
ctx.wrap_and_add_command(jit(reset_process.pure), name="reset")
38-
#"""
39-
40-
"""
41-
reset_cmd, reset_args = ctx.compile_by_key(a, compile_key="reset")
42-
ctx.add_command(wrap_command(jit(ctx.reset)), name="reset")
43-
advance_cmd, advance_args = ctx.compile_by_key(a, compile_key="advance_state")
44-
ctx.add_command(wrap_command(jit(ctx.advance_state)), name="run")
45-
"""
46-
30+
# ctx.wrap_and_add_command(jit(reset_process.pure), name="reset")
31+
# """
4732
## set up non-compiled utility commands
48-
@Context.dynamicCommand
49-
def clamp(x):
50-
a.j.set(x)
33+
# @Context.dynamicCommand
34+
# def clamp(x):
35+
# a.j.set(x)
36+
37+
def clamp(x):
38+
a.j.set(x)
5139

5240
## input spike train
5341
x_seq = jnp.asarray([[0., 0., 1., 1., 1., 1., 0., 0., 0., 0.]], dtype=jnp.float32)
5442
## desired output/epsp pulses
5543
y_seq = jnp.asarray([[0., 0., 0., 0., 0., 1., 0., 0., 0., 0.]], dtype=jnp.float32)
5644

5745
outs = []
58-
ctx.reset()
46+
reset_process.run() # ctx.reset()
5947
for ts in range(x_seq.shape[1]):
60-
x_t = x_seq[:, ts:ts+1] ## get data at time t
61-
ctx.clamp(x_t)
62-
ctx.run(t=ts * 1., dt=dt)
63-
outs.append(a.s.value)
48+
x_t = jnp.array([[x_seq[0, ts]]]) ## get data at time t
49+
clamp(x_t) # ctx.clamp(x_t)
50+
advance_process.run(t=ts * 1., dt=dt) # ctx.run(t=ts * 1., dt=dt)
51+
outs.append(a.s.get())
52+
6453
outs = jnp.concatenate(outs, axis=1)
65-
#print(outs)
54+
# print(outs)
55+
# print(y_seq)
6656

6757
## output should equal input
6858
assert_array_equal(outs, y_seq)

0 commit comments

Comments
 (0)