Skip to content

Commit b3c47a2

Browse files
author
Alexander Ororbia
committed
wrote+unit-test of hodgkin-huxley spike cell, minor tweaks/clean-up elsewhere
1 parent 098f3db commit b3c47a2

File tree

6 files changed

+329
-1
lines changed

6 files changed

+329
-1
lines changed

ngclearn/components/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
from .neurons.spiking.adExCell import AdExCell
1919
from .neurons.spiking.fitzhughNagumoCell import FitzhughNagumoCell
2020
from .neurons.spiking.izhikevichCell import IzhikevichCell
21+
from .neurons.spiking.hodgkinHuxleyCell import HodgkinHuxleyCell
2122
from .neurons.spiking.RAFCell import RAFCell
2223

2324
## point to transformer/operater component types

ngclearn/components/neurons/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,4 +13,5 @@
1313
from .spiking.adExCell import AdExCell
1414
from .spiking.fitzhughNagumoCell import FitzhughNagumoCell
1515
from .spiking.izhikevichCell import IzhikevichCell
16+
from .spiking.hodgkinHuxleyCell import HodgkinHuxleyCell
1617
from .spiking.RAFCell import RAFCell

ngclearn/components/neurons/spiking/LIFCell.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616

1717
#@jit
1818
def _dfv_internal(j, v, rfr, tau_m, refract_T, v_rest, v_decay=1.): ## raw voltage dynamics
19-
mask = (rfr >= refract_T).astype(jnp.float32) # get refractory mask
19+
mask = (rfr >= refract_T) * 1. # get refractory mask
2020
## update voltage / membrane potential
2121
dv_dt = (v_rest - v) * v_decay + (j * mask)
2222
dv_dt = dv_dt * (1./tau_m)

ngclearn/components/neurons/spiking/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,3 +8,4 @@
88
from .fitzhughNagumoCell import FitzhughNagumoCell
99
from .izhikevichCell import IzhikevichCell
1010
from .RAFCell import RAFCell
11+
from .hodgkinHuxleyCell import HodgkinHuxleyCell
Lines changed: 250 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,250 @@
1+
from ngclearn.components.jaxComponent import JaxComponent
2+
from jax import numpy as jnp, random, jit, nn
3+
from functools import partial
4+
from ngclearn.utils import tensorstats
5+
from ngcsimlib.deprecators import deprecate_args
6+
from ngcsimlib.logger import info, warn
7+
from ngclearn.utils.diffeq.ode_utils import get_integrator_code, \
8+
step_euler, step_rk2
9+
10+
from ngcsimlib.compilers.process import transition
11+
#from ngcsimlib.component import Component
12+
from ngcsimlib.compartment import Compartment
13+
14+
15+
def _calc_biophysical_constants(v): ## computes H-H biophysical constants (which are functions of voltage v)
16+
alpha_n_of_v = .01 * ((10 - v) / (jnp.exp((10. - v) / 10.) - 1.))
17+
beta_n_of_v = .125 * jnp.exp(-v / 80.)
18+
alpha_m_of_v = .1 * ((25 - v) / (jnp.exp((25. - v) / 10.) - 1.))
19+
beta_m_of_v = 4. * jnp.exp(-v / 18.)
20+
alpha_h_of_v = .07 * jnp.exp(-v / 20.)
21+
beta_h_of_v = 1. / (jnp.exp((30 - v) / 10.) + 1.)
22+
return alpha_n_of_v, beta_n_of_v, alpha_m_of_v, beta_m_of_v, alpha_h_of_v, beta_h_of_v
23+
24+
def _dv_dt(t, v, j, m, n, h, tau_v, g_Na, g_K, g_L, v_Na, v_K, v_L): ## ODE for membrane potential/voltage
25+
## C dv/dt = j - g_Na * m^3 * h * (v - v_Na) - g_K * n^4 * (v - v_K) - g_L * (v - v_L)
26+
term1 = g_Na * jnp.power(m, 3) * h * (v - v_Na)
27+
term2 = g_K * jnp.power(n, 4) * (v - v_K)
28+
term3 = g_L * (v - v_L)
29+
return (j - term1 - term2 - term3) * (1. / tau_v)
30+
31+
def dv_dt(t, v, params):
32+
j, m, n, h, tau_v, g_Na, g_K, g_L, v_Na, v_K, v_L = params
33+
return _dv_dt(t, v, j, m, n, h, tau_v, g_Na, g_K, g_L, v_Na, v_K, v_L)
34+
35+
def _dx_dt(t, x, alpha_x_of_v, beta_x_of_v): ## ODE for channel/gate
36+
## dx/dt = alpha_x(v) * (1 - x) - beta_x(v) * x
37+
return alpha_x_of_v * (1 - x) - beta_x_of_v * x
38+
39+
def dx_dt(t, x, params):
40+
alpha_x_of_v, beta_x_of_v = params
41+
return _dx_dt(t, x, alpha_x_of_v, beta_x_of_v)
42+
43+
class HodgkinHuxleyCell(JaxComponent): ## Hodgkin-Huxley spiking cell
44+
"""
45+
A spiking cell based the Hodgkin-Huxley (H-H) 1952 set of dynamics for describing the ionic mechanisms that underwrite
46+
the initiation and propagation of action potentials within a (giant) squid axon.
47+
48+
The four differential equations for adjusting this specific cell
49+
(for adjusting v, given current j, over time) is:
50+
51+
| tau_v dv/dt = j - g_Na * m^3 * h * (v - v_Na) - g_K * n^4 * (v - v_K) - g_L * (v - v_L)
52+
| dn/dt = alpha_n(v) * (1 - n) - beta_n(v) * n
53+
| dm/dt = alpha_m(v) * (1 - m) - beta_m(v) * m
54+
| dh/dt = alpha_h(v) * (1 - h) - beta_h(v) * h
55+
| where alpha_x(v) and beta_x(v) are functions that produce relevant biophysical constant values
56+
| depending on which gate/channel is being probed (i.e., x = n or m or h)
57+
58+
| --- Cell Input Compartments: ---
59+
| j - electrical current input (takes in external signals)
60+
| --- Cell State Compartments: ---
61+
| v - membrane potential/voltage state
62+
| n - dimensionless probabilities for potassium channel subunit activation
63+
| m - dimensionless probabilities for sodium channel subunit activation
64+
| h - dimensionless probabilities for sodium channel subunit inactivation
65+
| key - JAX PRNG key
66+
| --- Cell Output Compartments: ---
67+
| s - emitted binary spikes/action potentials
68+
| tols - time-of-last-spike
69+
70+
| References:
71+
| Hodgkin, Alan L., and Andrew F. Huxley. "A quantitative description of membrane current and its application to
72+
| conduction and excitation in nerve." The Journal of physiology 117.4 (1952): 500.
73+
|
74+
| Kistler, Werner M., Wulfram Gerstner, and J. Leo van Hemmen. "Reduction of the Hodgkin-Huxley equations to a
75+
| single-variable threshold model." Neural computation 9.5 (1997): 1015-1045.
76+
77+
Args:
78+
name: the string name of this cell
79+
80+
n_units: number of cellular entities (neural population size)
81+
82+
tau_v: membrane time constant (Default: 1 ms)
83+
84+
resist_m: membrane resistance value
85+
86+
v_Na: sodium reversal potential
87+
88+
v_K: potassium reversal potential
89+
90+
v_L: leak reversal potential
91+
92+
g_Na: sodium (Na) conductance per unit area
93+
94+
g_K: potassium (K) conductance per unit area
95+
96+
g_L: leak conductance per unit area
97+
98+
thr: voltage/membrane threshold (to obtain action potentials in terms of binary spikes/pulses)
99+
100+
spike_reset: if True, once voltage crosses threshold, then dynamics of voltage and recovery are reset/snapped
101+
to `v_reset` which has a default value of 0 mV (Default: False)
102+
103+
v_reset: voltage value to reset to after a spike (in mV)
104+
(Default: 0 mV)
105+
"""
106+
107+
# Define Functions
108+
def __init__(
109+
self, name, n_units, tau_v, resist_m=1., v_Na=115., v_K=-35., v_L=10.6, g_Na=100., g_K=5., g_L=0.3, thr=4.,
110+
spike_reset=False, v_reset=0., **kwargs
111+
):
112+
super().__init__(name, **kwargs)
113+
114+
## membrane parameter setup (affects ODE integration)
115+
self.tau_v = tau_v ## membrane time constant
116+
self.R_m = resist_m ## resistance value
117+
self.spike_reset = spike_reset
118+
self.thr = thr # mV ## base value for threshold
119+
self.v_reset = v_reset ## base value to reset voltage to (if spike_reset = True)
120+
self.v_Na = v_Na #115. ## ENa
121+
self.v_K = v_K #-35. #-12. ## EK
122+
self.v_L = v_L #10.6 ## EKleak
123+
self.g_Na = g_Na #100. #120. ## gNa
124+
self.g_K = g_K #5. #36. ## gK
125+
self.g_L = g_L #0.3 ## gKleak
126+
127+
## Layer Size Setup
128+
self.batch_size = 1
129+
self.n_units = n_units
130+
131+
## Compartment setup
132+
restVals = jnp.zeros((self.batch_size, self.n_units))
133+
self.j = Compartment(restVals, display_name="Electrical input current")
134+
self.v = Compartment(restVals, display_name="Membrane potential/voltage")
135+
self.n = Compartment(restVals, display_name="Potassium channel subunit activation (probability)")
136+
self.m = Compartment(restVals, display_name="Sodium channel subunit activation (probability)")
137+
self.h = Compartment(restVals, display_name="Sodium channel subunit inactivation (probability)")
138+
self.s = Compartment(restVals, display_name="Spike pulse")
139+
self.tols = Compartment(restVals, display_name="Time-of-last-spike") ## time-of-last-spike
140+
141+
@transition(output_compartments=["v", "m", "n", "h", "s", "tols"])
142+
@staticmethod
143+
def advance_state(t, dt, spike_reset, v_reset, thr, tau_v, R_m, g_Na, g_K, g_L, v_Na, v_K, v_L, j, v, m, n, h, tols):
144+
_j = j * R_m
145+
alpha_n_of_v, beta_n_of_v, alpha_m_of_v, beta_m_of_v, alpha_h_of_v, beta_h_of_v = _calc_biophysical_constants(v)
146+
## integrate voltage / membrane potential
147+
_, _v = step_euler(0., v, dv_dt, dt, (_j, m + 0., n + 0., h + 0., tau_v, g_Na, g_K, g_L, v_Na, v_K, v_L))
148+
## next, integrate different channels
149+
_, _n = step_euler(0., n, dx_dt, dt, (alpha_n_of_v, beta_n_of_v))
150+
_, _m = step_euler(0., m, dx_dt, dt, (alpha_m_of_v, beta_m_of_v))
151+
_, _h = step_euler(0., h, dx_dt, dt, (alpha_h_of_v, beta_h_of_v))
152+
## obtain action potentials/spikes/pulses
153+
s = (_v > thr) * 1.
154+
if spike_reset: ## if spike-reset used, variables snapped back to initial conditions
155+
alpha_n_of_v, beta_n_of_v, alpha_m_of_v, beta_m_of_v, alpha_h_of_v, beta_h_of_v = (
156+
_calc_biophysical_constants(v * 0 + v_reset))
157+
_v = _v * (1. - s) + s * v_reset
158+
_n = _n * (1. - s) + s * (alpha_n_of_v / (alpha_n_of_v + beta_n_of_v))
159+
_m = _m * (1. - s) + s * (alpha_m_of_v / (alpha_m_of_v + beta_m_of_v))
160+
_h = _h * (1. - s) + s * (alpha_h_of_v / (alpha_h_of_v + beta_h_of_v))
161+
## transition to new state of (system of) variables
162+
v = _v
163+
m = _m
164+
n = _n
165+
h = _h
166+
tols = (1. - s) * tols + (s * t) ## update tols
167+
168+
return v, m, n, h, s, tols
169+
170+
@transition(output_compartments=["j", "v", "m", "n", "h", "s", "tols"])
171+
@staticmethod
172+
def reset(batch_size, n_units):
173+
restVals = jnp.zeros((batch_size, n_units))
174+
v = restVals # + 0
175+
alpha_n_of_v, beta_n_of_v, alpha_m_of_v, beta_m_of_v, alpha_h_of_v, beta_h_of_v = _calc_biophysical_constants(v)
176+
j = restVals #+ 0
177+
n = alpha_n_of_v / (alpha_n_of_v + beta_n_of_v)
178+
m = alpha_m_of_v / (alpha_m_of_v + beta_m_of_v)
179+
h = alpha_h_of_v / (alpha_h_of_v + beta_h_of_v)
180+
s = restVals #+ 0
181+
tols = restVals #+ 0
182+
return j, v, m, n, h, s, tols
183+
184+
def save(self, directory, **kwargs):
185+
file_name = directory + "/" + self.name + ".npz"
186+
#jnp.savez(file_name, threshold=self.thr.value)
187+
188+
def load(self, directory, seeded=False, **kwargs):
189+
file_name = directory + "/" + self.name + ".npz"
190+
data = jnp.load(file_name)
191+
#self.thr.set( data['threshold'] )
192+
193+
@classmethod
194+
def help(cls): ## component help function
195+
properties = {
196+
"cell_type": "WTASCell - evolves neurons according to winner-take-all "
197+
"spiking dynamics "
198+
}
199+
compartment_props = {
200+
"inputs":
201+
{"j": "External input electrical current"},
202+
"states":
203+
{"v": "Membrane potential/voltage at time t",
204+
"n": "Current state of potassium channel subunit activation",
205+
"m": "Current state of sodium channel subunit activation",
206+
"h": "Current state of sodium channel subunit inactivation",
207+
"key": "JAX PRNG key"},
208+
"outputs":
209+
{"s": "Emitted spikes/pulses at time t",
210+
"tols": "Time-of-last-spike"},
211+
}
212+
hyperparams = {
213+
"n_units": "Number of neuronal cells to model in this layer",
214+
"tau_v": "Cell membrane time constant",
215+
"resist_m": "Membrane resistance value",
216+
"thr": "Base voltage threshold value",
217+
"v_Na": "Sodium reversal potential",
218+
"v_K": "Potassium reversal potential",
219+
"v_L": "Leak reversal potential",
220+
"g_Na": "Sodium conductance per unit area",
221+
"g_K": "Potassium conductance per unit area",
222+
"g_L": "Leak conductance per unit area",
223+
"spike_reset": "Should this cell hyperpolarize by snapping to base values or not?",
224+
"v_reset": "Voltage value to reset to after a spike"
225+
}
226+
info = {cls.__name__: properties,
227+
"compartments": compartment_props,
228+
"dynamics": "tau_v dv/dt = j - g_Na * m^3 * h * (v - v_Na) - g_K * n^4 * (v - v_K) - g_L * (v - v_L)",
229+
"hyperparameters": hyperparams}
230+
return info
231+
232+
def __repr__(self):
233+
comps = [varname for varname in dir(self) if Compartment.is_compartment(getattr(self, varname))]
234+
maxlen = max(len(c) for c in comps) + 5
235+
lines = f"[{self.__class__.__name__}] PATH: {self.name}\n"
236+
for c in comps:
237+
stats = tensorstats(getattr(self, c).value)
238+
if stats is not None:
239+
line = [f"{k}: {v}" for k, v in stats.items()]
240+
line = ", ".join(line)
241+
else:
242+
line = "None"
243+
lines += f" {f'({c})'.ljust(maxlen)}{line}\n"
244+
return lines
245+
246+
if __name__ == '__main__':
247+
from ngcsimlib.context import Context
248+
with Context("Bar") as bar:
249+
X = HodgkinHuxleyCell("X", 1, 1.)
250+
print(X)
Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
from jax import numpy as jnp, random, jit
2+
from ngcsimlib.context import Context
3+
import numpy as np
4+
5+
np.random.seed(42)
6+
from ngclearn.components import HodgkinHuxleyCell
7+
from ngcsimlib.compilers import compile_command, wrap_command
8+
from numpy.testing import assert_array_almost_equal
9+
10+
from ngcsimlib.compilers.process import Process, transition
11+
from ngcsimlib.component import Component
12+
from ngcsimlib.compartment import Compartment
13+
from ngcsimlib.context import Context
14+
from ngcsimlib.utils.compartment import Get_Compartment_Batch
15+
16+
import matplotlib.pyplot as plt
17+
18+
19+
def test_hodgkinHuxleyCell1():
20+
name = "hh_ctx"
21+
## create seeding keys
22+
dkey = random.PRNGKey(1234)
23+
dkey, *subkeys = random.split(dkey, 6)
24+
dt = 0.01 # 1. # ms
25+
26+
# ---- build a simple Poisson cell system ----
27+
with Context(name) as ctx:
28+
a = HodgkinHuxleyCell(
29+
name="a", n_units=1, tau_v=1., resist_m=1., key=subkeys[0]
30+
)
31+
32+
# """
33+
advance_process = (Process()
34+
>> a.advance_state)
35+
# ctx.wrap_and_add_command(advance_process.pure, name="run")
36+
ctx.wrap_and_add_command(jit(advance_process.pure), name="run")
37+
38+
reset_process = (Process()
39+
>> a.reset)
40+
ctx.wrap_and_add_command(jit(reset_process.pure), name="reset")
41+
# """
42+
43+
"""
44+
reset_cmd, reset_args = ctx.compile_by_key(a, compile_key="reset")
45+
ctx.add_command(wrap_command(jit(ctx.reset)), name="reset")
46+
advance_cmd, advance_args = ctx.compile_by_key(a, compile_key="advance_state")
47+
ctx.add_command(wrap_command(jit(ctx.advance_state)), name="run")
48+
"""
49+
50+
## set up non-compiled utility commands
51+
@Context.dynamicCommand
52+
def clamp(x):
53+
a.j.set(x)
54+
55+
## input spike train
56+
x_seq = jnp.zeros((1, 20))
57+
y_seq = jnp.array(
58+
[[
59+
0.02414415, 0.04820144, 0.07217567, 0.09607048, 0.11988933, 0.14363553, 0.16731221, 0.19092241,
60+
0.21446899, 0.23795472, 0.26138224, 0.28475408, 0.30807265, 0.3313403, 0.35455925, 0.37773165,
61+
0.40085957, 0.42394499, 0.44698984, 0.46999594]], dtype=jnp.float32)
62+
63+
v = []
64+
ctx.reset()
65+
for ts in range(x_seq.shape[1]):
66+
x_t = jnp.array([[x_seq[0, ts]]]) ## get data at time t
67+
ctx.clamp(x_t)
68+
ctx.run(t=ts * 1., dt=dt)
69+
v.append(a.v.value[0, 0])
70+
outs = jnp.array(v)
71+
diff = np.abs(outs - y_seq)
72+
## delta/error should be approximately zero
73+
assert_array_almost_equal(diff, diff * 0., decimal=6)
74+
75+
#test_hodgkinHuxleyCell1()

0 commit comments

Comments
 (0)