Skip to content

Commit 947d2cf

Browse files
author
Alexander Ororbia
committed
refactored and tests passed for izh and h-h cells
1 parent 37ecdcd commit 947d2cf

File tree

5 files changed

+150
-159
lines changed

5 files changed

+150
-159
lines changed

ngclearn/components/neurons/spiking/adExCell.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,7 @@
44
from ngclearn.utils import tensorstats
55
from ngcsimlib import deprecate_args
66
from ngcsimlib.logger import info, warn
7-
from ngclearn.utils.diffeq.ode_utils import get_integrator_code, \
8-
step_euler, step_rk2
7+
from ngclearn.utils.diffeq.ode_utils import get_integrator_code, step_euler, step_rk2
98

109
from ngcsimlib.parser import compilable
1110
from ngcsimlib.compartment import Compartment

ngclearn/components/neurons/spiking/hodgkinHuxleyCell.py

Lines changed: 68 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,11 @@
22
from jax import numpy as jnp, random, jit, nn
33
from functools import partial
44
from ngclearn.utils import tensorstats
5-
from ngcsimlib.deprecators import deprecate_args
5+
from ngcsimlib import deprecate_args
66
from ngcsimlib.logger import info, warn
7-
from ngclearn.utils.diffeq.ode_utils import get_integrator_code, \
8-
step_euler, step_rk2, step_rk4
7+
from ngclearn.utils.diffeq.ode_utils import get_integrator_code, step_euler, step_rk2, step_rk4
98

10-
from ngcsimlib.compilers.process import transition
11-
#from ngcsimlib.component import Component
9+
from ngcsimlib.parser import compilable
1210
from ngcsimlib.compartment import Compartment
1311

1412

@@ -113,7 +111,6 @@ class HodgkinHuxleyCell(JaxComponent): ## Hodgkin-Huxley spiking cell
113111
at an increase in computational cost (and simulation time)
114112
"""
115113

116-
# Define Functions
117114
def __init__(
118115
self, name, n_units, tau_v, resist_m=1., v_Na=115., v_K=-35., v_L=10.6, g_Na=100., g_K=5., g_L=0.3, thr=4.,
119116
spike_reset=False, v_reset=0., integration_type="euler", **kwargs
@@ -126,7 +123,7 @@ def __init__(
126123

127124
## cell properties / biophysical parameter setup (affects ODE integration)
128125
self.tau_v = tau_v ## membrane time constant
129-
self.R_m = resist_m ## resistance value
126+
self.resist_m = resist_m ## resistance value R_m
130127
self.spike_reset = spike_reset
131128
self.thr = thr # mV ## base value for threshold
132129
self.v_reset = v_reset ## base value to reset voltage to (if spike_reset = True)
@@ -151,38 +148,49 @@ def __init__(
151148
self.s = Compartment(restVals, display_name="Spike pulse")
152149
self.tols = Compartment(restVals, display_name="Time-of-last-spike") ## time-of-last-spike
153150

154-
@transition(output_compartments=["v", "m", "n", "h", "s", "tols"])
155-
@staticmethod
156-
def advance_state(
157-
t, dt, spike_reset, v_reset, thr, tau_v, R_m, g_Na, g_K, g_L, v_Na, v_K, v_L, j, v, m, n, h, tols, intgFlag
158-
):
159-
_j = j * R_m
160-
alpha_n_of_v, beta_n_of_v, alpha_m_of_v, beta_m_of_v, alpha_h_of_v, beta_h_of_v = _calc_biophysical_constants(v)
151+
#@transition(output_compartments=["v", "m", "n", "h", "s", "tols"])
152+
#@staticmethod
153+
@compilable
154+
def advance_state(self, t, dt): #t, dt, spike_reset, v_reset, thr, tau_v, R_m, g_Na, g_K, g_L, v_Na, v_K, v_L, j, v, m, n, h, tols, intgFlag
155+
_j = self.j.get() * self.resist_m
156+
alpha_n_of_v, beta_n_of_v, alpha_m_of_v, beta_m_of_v, alpha_h_of_v, beta_h_of_v = _calc_biophysical_constants(self.v.get())
161157
## integrate voltage / membrane potential
162-
if intgFlag == 1: ## midpoint method
163-
_, _v = step_rk2(0., v, dv_dt, dt, (_j, m + 0., n + 0., h + 0., tau_v, g_Na, g_K, g_L, v_Na, v_K, v_L))
158+
if self.intgFlag == 1: ## midpoint method
159+
_, _v = step_rk2(
160+
0., self.v.get(), dv_dt, dt,
161+
(_j, self.m.get() + 0., self.n.get() + 0., self.h.get() + 0., self.tau_v, self.g_Na, self.g_K,
162+
self.g_L, self.v_Na, self.v_K, self.v_L)
163+
)
164164
## next, integrate different channels
165-
_, _n = step_rk2(0., n, dx_dt, dt, (alpha_n_of_v, beta_n_of_v))
166-
_, _m = step_rk2(0., m, dx_dt, dt, (alpha_m_of_v, beta_m_of_v))
167-
_, _h = step_rk2(0., h, dx_dt, dt, (alpha_h_of_v, beta_h_of_v))
168-
elif intgFlag == 4: ## Runge-Kutta 4th order
169-
_, _v = step_rk4(0., v, dv_dt, dt, (_j, m + 0., n + 0., h + 0., tau_v, g_Na, g_K, g_L, v_Na, v_K, v_L))
165+
_, _n = step_rk2(0., self.n.get(), dx_dt, dt, (alpha_n_of_v, beta_n_of_v))
166+
_, _m = step_rk2(0., self.m.get(), dx_dt, dt, (alpha_m_of_v, beta_m_of_v))
167+
_, _h = step_rk2(0., self.h.get(), dx_dt, dt, (alpha_h_of_v, beta_h_of_v))
168+
elif self.intgFlag == 4: ## Runge-Kutta 4th order
169+
_, _v = step_rk4(
170+
0., self.v.get(), dv_dt, dt,
171+
(_j, self.m.get() + 0., self.n.get() + 0., self.h.get() + 0., self.tau_v, self.g_Na, self.g_K,
172+
self.g_L, self.v_Na, self.v_K, self.v_L)
173+
)
170174
## next, integrate different channels
171-
_, _n = step_rk4(0., n, dx_dt, dt, (alpha_n_of_v, beta_n_of_v))
172-
_, _m = step_rk4(0., m, dx_dt, dt, (alpha_m_of_v, beta_m_of_v))
173-
_, _h = step_rk4(0., h, dx_dt, dt, (alpha_h_of_v, beta_h_of_v))
175+
_, _n = step_rk4(0., self.n.get(), dx_dt, dt, (alpha_n_of_v, beta_n_of_v))
176+
_, _m = step_rk4(0., self.m.get(), dx_dt, dt, (alpha_m_of_v, beta_m_of_v))
177+
_, _h = step_rk4(0., self.h.get(), dx_dt, dt, (alpha_h_of_v, beta_h_of_v))
174178
else: # integType == 0 (default -- Euler)
175-
_, _v = step_euler(0., v, dv_dt, dt, (_j, m + 0., n + 0., h + 0., tau_v, g_Na, g_K, g_L, v_Na, v_K, v_L))
179+
_, _v = step_euler(
180+
0., self.v.get(), dv_dt, dt,
181+
(_j, self.m.get() + 0., self.n.get() + 0., self.h.get() + 0., self.tau_v, self.g_Na, self.g_K,
182+
self.g_L, self.v_Na, self.v_K, self.v_L)
183+
)
176184
## next, integrate different channels
177-
_, _n = step_euler(0., n, dx_dt, dt, (alpha_n_of_v, beta_n_of_v))
178-
_, _m = step_euler(0., m, dx_dt, dt, (alpha_m_of_v, beta_m_of_v))
179-
_, _h = step_euler(0., h, dx_dt, dt, (alpha_h_of_v, beta_h_of_v))
185+
_, _n = step_euler(0., self.n.get(), dx_dt, dt, (alpha_n_of_v, beta_n_of_v))
186+
_, _m = step_euler(0., self.m.get(), dx_dt, dt, (alpha_m_of_v, beta_m_of_v))
187+
_, _h = step_euler(0., self.h.get(), dx_dt, dt, (alpha_h_of_v, beta_h_of_v))
180188
## obtain action potentials/spikes/pulses
181-
s = (_v > thr) * 1.
182-
if spike_reset: ## if spike-reset used, variables snapped back to initial conditions
189+
s = (_v > self.thr) * 1.
190+
if self.spike_reset: ## if spike-reset used, variables snapped back to initial conditions
183191
alpha_n_of_v, beta_n_of_v, alpha_m_of_v, beta_m_of_v, alpha_h_of_v, beta_h_of_v = (
184-
_calc_biophysical_constants(v * 0 + v_reset))
185-
_v = _v * (1. - s) + s * v_reset
192+
_calc_biophysical_constants(self.v.get() * 0 + self.v_reset))
193+
_v = _v * (1. - s) + s * self.v_reset
186194
_n = _n * (1. - s) + s * (alpha_n_of_v / (alpha_n_of_v + beta_n_of_v))
187195
_m = _m * (1. - s) + s * (alpha_m_of_v / (alpha_m_of_v + beta_m_of_v))
188196
_h = _h * (1. - s) + s * (alpha_h_of_v / (alpha_h_of_v + beta_h_of_v))
@@ -191,32 +199,40 @@ def advance_state(
191199
m = _m
192200
n = _n
193201
h = _h
194-
tols = (1. - s) * tols + (s * t) ## update tols
202+
## update time-of-last spike variable(s)
203+
self.tols.set((1. - s) * self.tols.get() + (s * t))
195204

196-
return v, m, n, h, s, tols
205+
self.v.set(v)
206+
self.m.set(m)
207+
self.n.set(n)
208+
self.h.set(h)
209+
self.s.set(s)
197210

198-
@transition(output_compartments=["j", "v", "m", "n", "h", "s", "tols"])
199-
@staticmethod
200-
def reset(batch_size, n_units):
201-
restVals = jnp.zeros((batch_size, n_units))
211+
@compilable
212+
def reset(self):
213+
restVals = jnp.zeros((self.batch_size, self.n_units))
202214
v = restVals # + 0
203215
alpha_n_of_v, beta_n_of_v, alpha_m_of_v, beta_m_of_v, alpha_h_of_v, beta_h_of_v = _calc_biophysical_constants(v)
204-
j = restVals #+ 0
216+
if not self.j.targeted:
217+
self.j.set(restVals)
205218
n = alpha_n_of_v / (alpha_n_of_v + beta_n_of_v)
206219
m = alpha_m_of_v / (alpha_m_of_v + beta_m_of_v)
207220
h = alpha_h_of_v / (alpha_h_of_v + beta_h_of_v)
208-
s = restVals #+ 0
209-
tols = restVals #+ 0
210-
return j, v, m, n, h, s, tols
211-
212-
def save(self, directory, **kwargs):
213-
file_name = directory + "/" + self.name + ".npz"
214-
#jnp.savez(file_name, threshold=self.thr.value)
215-
216-
def load(self, directory, seeded=False, **kwargs):
217-
file_name = directory + "/" + self.name + ".npz"
218-
data = jnp.load(file_name)
219-
#self.thr.set( data['threshold'] )
221+
self.v.set(v)
222+
self.n.set(n)
223+
self.m.set(m)
224+
self.h.set(h)
225+
self.s.set(restVals)
226+
self.tols.set(restVals)
227+
228+
# def save(self, directory, **kwargs):
229+
# file_name = directory + "/" + self.name + ".npz"
230+
# #jnp.savez(file_name, threshold=self.thr.value)
231+
#
232+
# def load(self, directory, seeded=False, **kwargs):
233+
# file_name = directory + "/" + self.name + ".npz"
234+
# data = jnp.load(file_name)
235+
# #self.thr.set( data['threshold'] )
220236

221237
@classmethod
222238
def help(cls): ## component help function
@@ -258,7 +274,7 @@ def help(cls): ## component help function
258274
return info
259275

260276
def __repr__(self):
261-
comps = [varname for varname in dir(self) if Compartment.is_compartment(getattr(self, varname))]
277+
comps = [varname for varname in dir(self) if isinstance(getattr(self, varname), Compartment)]
262278
maxlen = max(len(c) for c in comps) + 5
263279
lines = f"[{self.__class__.__name__}] PATH: {self.name}\n"
264280
for c in comps:

ngclearn/components/neurons/spiking/izhikevichCell.py

Lines changed: 38 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -2,16 +2,13 @@
22
from jax import numpy as jnp, random, jit, nn
33
from functools import partial
44
from ngclearn.utils import tensorstats
5-
from ngcsimlib.deprecators import deprecate_args
5+
from ngcsimlib import deprecate_args
66
from ngcsimlib.logger import info, warn
7-
from ngclearn.utils.diffeq.ode_utils import get_integrator_code, \
8-
step_euler, step_rk2
7+
from ngclearn.utils.diffeq.ode_utils import get_integrator_code, step_euler, step_rk2
98

10-
from ngcsimlib.compilers.process import transition
11-
#from ngcsimlib.component import Component
9+
from ngcsimlib.parser import compilable
1210
from ngcsimlib.compartment import Compartment
1311

14-
1512
@jit
1613
def _dfv_internal(j, v, w, b, tau_m): ## raw voltage dynamics
1714
## (v^2 * 0.04 + v * 5 + 140 - u + j) * a, where a = (1./tau_m) (w = u)
@@ -119,17 +116,16 @@ class IzhikevichCell(JaxComponent): ## Izhikevich neuronal cell
119116
at an increase in computational cost (and simulation time)
120117
"""
121118

122-
# Define Functions
123119
def __init__(self, name, n_units, tau_m=1., resist_m=1., v_thr=30., v_reset=-65.,
124120
tau_w=50., w_reset=8., coupling_factor=0.2, v0=-65., w0=-14.,
125121
integration_type="euler", **kwargs):
126122
super().__init__(name, **kwargs)
127123

128124
## Cell properties
129-
self.R_m = resist_m
125+
self.resist_m = resist_m ## resistance R_m
130126
self.tau_m = tau_m
131127
self.tau_w = tau_w
132-
self.coupling = coupling_factor
128+
self.coupling_factor = coupling_factor
133129
self.v_reset = v_reset
134130
self.w_reset = w_reset
135131

@@ -153,45 +149,47 @@ def __init__(self, name, n_units, tau_m=1., resist_m=1., v_thr=30., v_reset=-65.
153149
self.s = Compartment(restVals)
154150
self.tols = Compartment(restVals) ## time-of-last-spike
155151

156-
@transition(output_compartments=["j", "v", "w", "s", "tols"])
157-
@staticmethod
158-
def advance_state(t, dt, tau_m, tau_w, v_thr, coupling, v_reset, w_reset, R_m,
159-
intgFlag, j, v, w, s, tols):
152+
@compilable
153+
def advance_state(self, t, dt):
160154
## note: a = 0.1 --> fast spikes, a = 0.02 --> regular spikes
161-
a = 1. / tau_w ## we map time constant to variable "a" (a = 1/tau_w)
162-
_j = j * R_m
155+
a = 1. / self.tau_w ## we map time constant to variable "a" (a = 1/tau_w)
156+
_j = self.j.get() * self.resist_m
163157
# _j = jnp.maximum(-30.0, _j) ## lower-bound/clip input current
164158
## check for spikes
165-
s = (v > v_thr) * 1.
159+
s = (self.v.get() > self.v_thr) * 1.
166160
## for non-spikes, evolve according to dynamics
167-
if intgFlag == 1:
168-
v_params = (_j, w, coupling, tau_m)
169-
_, _v = step_rk2(0., v, _dfv, dt, v_params) # _v = step_rk2(v, v_params, _dfv, dt)
170-
w_params = (_j, v, coupling, tau_w)
171-
_, _w = step_rk2(0., w, _dfw, dt, w_params) # _w = step_rk2(w, w_params, _dfw, dt)
161+
if self.intgFlag == 1:
162+
v_params = (_j, self.w.get(), self.coupling_factor, self.tau_m)
163+
_, _v = step_rk2(0., self.v.get(), _dfv, dt, v_params) # _v = step_rk2(v, v_params, _dfv, dt)
164+
w_params = (_j, self.v.get(), self.coupling_factor, self.tau_w)
165+
_, _w = step_rk2(0., self.w.get(), _dfw, dt, w_params) # _w = step_rk2(w, w_params, _dfw, dt)
172166
else: # integType == 0 (default -- Euler)
173-
v_params = (_j, w, coupling, tau_m)
174-
_, _v = step_euler(0., v, _dfv, dt, v_params) # _v = step_euler(v, v_params, _dfv, dt)
175-
w_params = (_j, v, coupling, tau_w)
176-
_, _w = step_euler(0., w, _dfw, dt, w_params) # _w = step_euler(w, w_params, _dfw, dt)
167+
v_params = (_j, self.w.get(), self.coupling_factor, self.tau_m)
168+
_, _v = step_euler(0., self.v.get(), _dfv, dt, v_params) # _v = step_euler(v, v_params, _dfv, dt)
169+
w_params = (_j, self.v.get(), self.coupling_factor, self.tau_w)
170+
_, _w = step_euler(0., self.w.get(), _dfw, dt, w_params) # _w = step_euler(w, w_params, _dfw, dt)
177171
## for spikes, snap to particular states
178-
_v, _w = _post_process(s, _v, _w, v, w, v_reset, w_reset)
172+
_v, _w = _post_process(s, _v, _w, self.v.get(), self.w.get(), self.v_reset, self.w_reset)
179173
v = _v
180174
w = _w
181175

182-
tols = (1. - s) * tols + (s * t) ## update tols
183-
return j, v, w, s, tols
176+
## update time-of-last spike variable(s)
177+
self.tols.set((1. - s) * self.tols.get() + (s * t))
178+
179+
# self.j.set(j) ## j is not getting modified in these dynamics
180+
self.v.set(v)
181+
self.w.set(w)
182+
self.s.set(s)
184183

185-
@transition(output_compartments=["j", "v", "w", "s", "tols"])
186-
@staticmethod
187-
def reset(batch_size, n_units, v0, w0):
188-
restVals = jnp.zeros((batch_size, n_units))
189-
j = restVals # None
190-
v = restVals + v0
191-
w = restVals + w0
192-
s = restVals #+ 0
193-
tols = restVals #+ 0
194-
return j, v, w, s, tols
184+
@compilable
185+
def reset(self):
186+
restVals = jnp.zeros((self.batch_size, self.n_units))
187+
if not self.j.targeted:
188+
self.j.set(restVals)
189+
self.v.set(restVals + self.v0)
190+
self.w.set(restVals + self.w0)
191+
self.s.set(restVals)
192+
self.tols.set(restVals)
195193

196194
@classmethod
197195
def help(cls): ## component help function
@@ -219,8 +217,7 @@ def help(cls): ## component help function
219217
"v_rest": "Resting membrane potential value",
220218
"v_reset": "Reset membrane potential value",
221219
"w_reset": "Reset recover variable value",
222-
"coupling_factor": "Degree to which recovery variable is sensitive to "
223-
"subthreshold voltage fluctuations",
220+
"coupling_factor": "Degree to which recovery variable is sensitive to subthreshold voltage fluctuations",
224221
"v0": "Initial condition for membrane potential/voltage",
225222
"w0": "Initial condition for recovery variable",
226223
"integration_type": "Type of numerical integration to use for the cell dynamics"
@@ -233,7 +230,7 @@ def help(cls): ## component help function
233230
return info
234231

235232
def __repr__(self):
236-
comps = [varname for varname in dir(self) if Compartment.is_compartment(getattr(self, varname))]
233+
comps = [varname for varname in dir(self) if isinstance(getattr(self, varname), Compartment)]
237234
maxlen = max(len(c) for c in comps) + 5
238235
lines = f"[{self.__class__.__name__}] PATH: {self.name}\n"
239236
for c in comps:

0 commit comments

Comments
 (0)