Skip to content

Commit db25673

Browse files
author
Alexander Ororbia
committed
revised quad-lif w/ unit-test
1 parent 3c1ee98 commit db25673

File tree

2 files changed

+220
-165
lines changed

2 files changed

+220
-165
lines changed
Lines changed: 151 additions & 165 deletions
Original file line numberDiff line numberDiff line change
@@ -1,39 +1,24 @@
1+
from ngclearn.components.jaxComponent import JaxComponent
12
from jax import numpy as jnp, random, jit, nn
23
from functools import partial
3-
import time, sys
44
from ngclearn.utils import tensorstats
5-
from ngclearn import resolver, Component, Compartment
5+
from ngcsimlib.deprecators import deprecate_args
6+
from ngcsimlib.logger import info, warn
67
from ngclearn.utils.diffeq.ode_utils import get_integrator_code, \
78
step_euler, step_rk2
8-
## import parent cell class/component
9-
from ngclearn.components.neurons.spiking.LIFCell import LIFCell
10-
11-
@jit
12-
def _update_times(t, s, tols):
13-
"""
14-
Updates time-of-last-spike (tols) variable.
15-
16-
Args:
17-
t: current time (a scalar/int value)
9+
from ngclearn.utils.surrogate_fx import (secant_lif_estimator, arctan_estimator,
10+
triangular_estimator,
11+
straight_through_estimator)
1812

19-
s: binary spike vector
13+
from ngcsimlib.compilers.process import transition
14+
#from ngcsimlib.component import Component
15+
from ngcsimlib.compartment import Compartment
2016

21-
tols: current time-of-last-spike variable
22-
23-
Returns:
24-
updated tols variable
25-
"""
26-
_tols = (1. - s) * tols + (s * t)
27-
return _tols
28-
29-
@jit
30-
def _modify_current(j, dt, tau_m): ## electrical current re-scaling co-routine
31-
jScale = tau_m/dt
32-
return j * jScale
17+
from ngclearn.components.neurons.spiking.LIFCell import LIFCell
3318

3419
@jit
3520
def _dfv_internal(j, v, rfr, tau_m, refract_T, v_rest, v_c, a0): ## raw voltage dynamics
36-
mask = (rfr >= refract_T).astype(jnp.float32) # get refractory mask
21+
mask = (rfr >= refract_T) * 1. # get refractory mask
3722
## update voltage / membrane potential
3823
dv_dt = ((v_rest - v) * (v - v_c) * a0) + (j * mask)
3924
dv_dt = dv_dt * (1./tau_m)
@@ -44,101 +29,18 @@ def _dfv(t, v, params): ## voltage dynamics wrapper
4429
dv_dt = _dfv_internal(j, v, rfr, tau_m, refract_T, v_rest, v_c, a0)
4530
return dv_dt
4631

47-
#@partial(jit, static_argnums=[7,8,9,10,11,12,13,14])
48-
def _run_cell(dt, j, v, v_thr, v_theta, rfr, skey, v_c, a0, tau_m, v_rest,
49-
v_reset, refract_T, integType=0):
50-
"""
51-
Runs quadratic leaky integrator neuronal dynamics
52-
53-
Args:
54-
dt: integration time constant (milliseconds, or ms)
55-
56-
j: electrical current value
57-
58-
v: membrane potential (voltage, in milliVolts or mV) value (at t)
59-
60-
v_thr: base voltage threshold value (in mV)
61-
62-
v_theta: threshold shift (homeostatic) variable (at t)
63-
64-
rfr: refractory variable vector (one per neuronal cell)
65-
66-
skey: PRNG key which, if not None, will trigger a single-spike constraint
67-
(i.e., only one spike permitted to emit per single step of time);
68-
specifically used to randomly sample one of the possible action
69-
potentials to be an emitted spike
70-
71-
v_c: scaling factor for voltage accumulation
72-
73-
a0: critical voltage value
74-
75-
tau_m: cell membrane time constant
76-
77-
v_rest: membrane resting potential (in mV)
78-
79-
v_reset: membrane reset potential (in mV) -- upon occurrence of a spike,
80-
a neuronal cell's membrane potential will be set to this value
81-
82-
refract_T: (relative) refractory time period (in ms; Default
83-
value is 1 ms)
84-
85-
integType: integer indicating type of integration to use
86-
87-
Returns:
88-
voltage(t+dt), spikes, raw spikes, updated refactory variables
89-
"""
90-
_v_thr = v_theta + v_thr ## calc present voltage threshold
91-
#mask = (rfr >= refract_T).astype(jnp.float32) # get refractory mask
92-
## update voltage / membrane potential (v_c ~> 0.8?) (a0 usually <1?)
93-
#_v = v + ((v_rest - v) * (v - v_c) * a0) * (dt/tau_m) + (j * mask)
94-
v_params = (j, rfr, tau_m, refract_T, v_rest, v_c, a0)
95-
if integType == 1:
96-
_, _v = step_rk2(0., v, _dfv, dt, v_params)
97-
else:
98-
_, _v = step_euler(0., v, _dfv, dt, v_params)
99-
## obtain action potentials
100-
s = (_v > _v_thr).astype(jnp.float32)
101-
## update refractory variables
102-
_rfr = (rfr + dt) * (1. - s)
103-
## perform hyper-polarization of neuronal cells
104-
_v = _v * (1. - s) + s * v_reset
105-
106-
raw_s = s + 0 ## preserve un-altered spikes
107-
############################################################################
108-
## this is a spike post-processing step
109-
if skey is not None: ## FIXME: this would not work for mini-batches!!!!!!!
110-
m_switch = (jnp.sum(s) > 0.).astype(jnp.float32)
111-
rS = random.choice(skey, s.shape[1], p=jnp.squeeze(s))
112-
rS = nn.one_hot(rS, num_classes=s.shape[1], dtype=jnp.float32)
113-
s = s * (1. - m_switch) + rS * m_switch
114-
############################################################################
115-
return _v, s, raw_s, _rfr
116-
117-
@partial(jit, static_argnums=[3,4])
32+
#@partial(jit, static_argnums=[3, 4])
11833
def _update_theta(dt, v_theta, s, tau_theta, theta_plus=0.05):
119-
"""
120-
Runs homeostatic threshold update dynamics one step.
121-
122-
Args:
123-
dt: integration time constant (milliseconds, or ms)
124-
125-
v_theta: current value of homeostatic threshold variable
126-
127-
s: current spikes (at t)
128-
129-
tau_theta: homeostatic threshold time constant
130-
131-
theta_plus: physical increment to be applied to any threshold value if
132-
a spike was emitted
133-
134-
Returns:
135-
updated homeostatic threshold variable
136-
"""
34+
### Runs homeostatic threshold update dynamics one step (via Euler integration).
35+
#theta_decay = 0.9999999 #0.999999762 #jnp.exp(-dt/1e7)
36+
#theta_plus = 0.05
37+
#_V_theta = V_theta * theta_decay + S * theta_plus
13738
theta_decay = jnp.exp(-dt/tau_theta)
13839
_v_theta = v_theta * theta_decay + s * theta_plus
40+
#_V_theta = V_theta + -V_theta * (dt/tau_theta) + S * alpha
13941
return _v_theta
14042

141-
class QuadLIFCell(LIFCell): ## quadratic (leaky) LIF cell; inherits from LIFCell
43+
class QuadLIFCell(LIFCell): ## quadratic integrate-and-fire cell
14244
"""
14345
A spiking cell based on quadratic leaky integrate-and-fire (LIF) neuronal
14446
dynamics. Note that QuadLIFCell is a child of LIFCell and inherits its
@@ -184,9 +86,9 @@ class QuadLIFCell(LIFCell): ## quadratic (leaky) LIF cell; inherits from LIFCell
18486
18587
v_scale: scaling factor for voltage accumulation (v_c)
18688
187-
critical_V: critical voltage value (a0)
89+
critical_v: critical voltage value (in mV) (i.e., variable name - a0)
18890
189-
tau_theta: homeostatic threshold time constant
91+
tau_theta: homeostatic threshold time constant
19092
19193
theta_plus: physical increment to be applied to any threshold value if
19294
a spike was emitted
@@ -198,58 +100,138 @@ class QuadLIFCell(LIFCell): ## quadratic (leaky) LIF cell; inherits from LIFCell
198100
a single spike will be permitted to emit per step -- this means that
199101
if > 1 spikes emitted, a single action potential will be randomly
200102
sampled from the non-zero spikes detected
201-
"""
202-
203-
# Define Functions
204-
def __init__(self, name, n_units, tau_m, resist_m=1., thr=-52., v_rest=-65.,
205-
v_reset=60., v_scale=-41.6, critical_V=1., tau_theta=1e7,
206-
theta_plus=0.05, refract_time=5., thr_jitter=0., one_spike=False,
207-
integration_type="euler", **kwargs):
208-
super().__init__(name, n_units, tau_m, resist_m, thr, v_rest, v_reset,
209-
1., tau_theta, theta_plus, refract_time, thr_jitter,
210-
one_spike, integration_type, **kwargs)
103+
""" ## batch_size arg?
104+
105+
@deprecate_args(thr_jitter=None)
106+
def __init__(
107+
self, name, n_units, tau_m, resist_m=1., thr=-52., v_rest=-65., v_reset=-60., v_scale=-41.6, critical_v=1.,
108+
tau_theta=1e7, theta_plus=0.05, refract_time=5., one_spike=False, integration_type="euler",
109+
surrgoate_type="straight_through", lower_clamp_voltage=True, **kwargs
110+
):
111+
super().__init__(
112+
name, n_units, tau_m, resist_m, thr, v_rest, v_reset, 1., tau_theta, theta_plus, refract_time,
113+
one_spike, integration_type, surrgoate_type, lower_clamp_voltage, **kwargs
114+
)
211115
## only two distinct additional constants distinguish the Quad-LIF cell
212116
self.v_c = v_scale
213-
self.a0 = critical_V
117+
self.a0 = critical_v
214118

119+
@transition(output_compartments=["v", "s", "s_raw", "rfr", "thr_theta", "tols", "key", "surrogate"])
215120
@staticmethod
216-
def _advance_state(t, dt, tau_m, R_m, v_rest, v_reset, refract_T, thr,
217-
tau_theta, theta_plus, one_spike, v_c, a0, intgFlag, key,
218-
j, v, s, rfr, thr_theta, tols):
219-
## Note: this runs quadratic LIF neuronal dynamics but constrained to be
220-
## similar to the general form of LIF dynamics
121+
def advance_state(
122+
t, dt, tau_m, resist_m, v_rest, v_reset, v_c, a0, refract_T, thr, tau_theta, theta_plus,
123+
one_spike, lower_clamp_voltage, intgFlag, d_spike_fx, key, j, v, rfr, thr_theta, tols
124+
):
221125
skey = None ## this is an empty dkey if single_spike mode turned off
222-
if one_spike: ## old code ~> if self.one_spike is False:
223-
key, *subkeys = random.split(key, 2)
224-
skey = subkeys[0]
126+
if one_spike:
127+
key, skey = random.split(key, 2)
225128
## run one integration step for neuronal dynamics
226-
j = j * R_m
227-
v, s, raw_spikes, rfr = _run_cell(dt, j, v, thr, thr_theta, rfr, skey,
228-
v_c, a0, tau_m, v_rest, v_reset,
229-
refract_T, intgFlag)
129+
j = j * resist_m
130+
############################################################################
131+
### Runs leaky integrator (leaky integrate-and-fire; LIF) neuronal dynamics.
132+
_v_thr = thr_theta + thr #v_theta + v_thr ## calc present voltage threshold
133+
#mask = (rfr >= refract_T).astype(jnp.float32) # get refractory mask
134+
## update voltage / membrane potential
135+
v_params = (j, rfr, tau_m, refract_T, v_rest, v_c, a0)
136+
if intgFlag == 1:
137+
_, _v = step_rk2(0., v, _dfv, dt, v_params)
138+
else: #_v = v + (v_rest - v) * (dt/tau_m) + (j * mask)
139+
_, _v = step_euler(0., v, _dfv, dt, v_params)
140+
## obtain action potentials/spikes
141+
s = (_v > _v_thr) * 1.
142+
## update refractory variables
143+
_rfr = (rfr + dt) * (1. - s)
144+
## perform hyper-polarization of neuronal cells
145+
_v = _v * (1. - s) + s * v_reset
146+
147+
raw_s = s + 0 ## preserve un-altered spikes
148+
############################################################################
149+
## this is a spike post-processing step
150+
if skey is not None:
151+
m_switch = (jnp.sum(s) > 0.).astype(jnp.float32) ## TODO: not batch-able
152+
rS = s * random.uniform(skey, s.shape)
153+
rS = nn.one_hot(jnp.argmax(rS, axis=1), num_classes=s.shape[1],
154+
dtype=jnp.float32)
155+
s = s * (1. - m_switch) + rS * m_switch
156+
############################################################################
157+
raw_spikes = raw_s
158+
v = _v
159+
rfr = _rfr
160+
161+
surrogate = d_spike_fx(v, _v_thr) #d_spike_fx(v, thr + thr_theta)
230162
if tau_theta > 0.:
231163
## run one integration step for threshold dynamics
232-
thr_theta = _update_theta(dt, thr_theta, raw_spikes, tau_theta,
233-
theta_plus)
164+
thr_theta = _update_theta(dt, thr_theta, raw_spikes, tau_theta, theta_plus)
234165
## update tols
235-
tols = _update_times(t, s, tols)
236-
return v, s, raw_spikes, rfr, thr_theta, tols, key
237-
238-
@resolver(_advance_state)
239-
def advance_state(self, v, s, s_raw, rfr, thr_theta, tols, key):
240-
self.v.set(v)
241-
self.s.set(s)
242-
self.s_raw.set(s_raw)
243-
self.rfr.set(rfr)
244-
self.thr_theta.set(thr_theta)
245-
self.tols.set(tols)
246-
self.key.set(key)
166+
tols = (1. - s) * tols + (s * t)
167+
if lower_clamp_voltage: ## ensure voltage never < v_rest
168+
v = jnp.maximum(v, v_rest)
169+
return v, s, raw_spikes, rfr, thr_theta, tols, key, surrogate
170+
171+
@transition(output_compartments=["j", "v", "s", "s_raw", "rfr", "tols", "surrogate"])
172+
@staticmethod
173+
def reset(batch_size, n_units, v_rest, refract_T):
174+
restVals = jnp.zeros((batch_size, n_units))
175+
j = restVals #+ 0
176+
v = restVals + v_rest
177+
s = restVals #+ 0
178+
s_raw = restVals
179+
rfr = restVals + refract_T
180+
#thr_theta = restVals ## do not reset thr_theta
181+
tols = restVals #+ 0
182+
surrogate = restVals + 1.
183+
return j, v, s, s_raw, rfr, tols, surrogate
184+
185+
def save(self, directory, **kwargs):
186+
## do a protected save of constants, depending on whether they are floats or arrays
187+
tau_m = (self.tau_m if isinstance(self.tau_m, float)
188+
else jnp.asarray([[self.tau_m * 1.]]))
189+
thr = (self.thr if isinstance(self.thr, float)
190+
else jnp.asarray([[self.thr * 1.]]))
191+
v_rest = (self.v_rest if isinstance(self.v_rest, float)
192+
else jnp.asarray([[self.v_rest * 1.]]))
193+
v_reset = (self.v_reset if isinstance(self.v_reset, float)
194+
else jnp.asarray([[self.v_reset * 1.]]))
195+
v_decay = (self.v_decay if isinstance(self.v_decay, float)
196+
else jnp.asarray([[self.v_decay * 1.]]))
197+
resist_m = (self.resist_m if isinstance(self.resist_m, float)
198+
else jnp.asarray([[self.resist_m * 1.]]))
199+
tau_theta = (self.tau_theta if isinstance(self.tau_theta, float)
200+
else jnp.asarray([[self.tau_theta * 1.]]))
201+
theta_plus = (self.theta_plus if isinstance(self.theta_plus, float)
202+
else jnp.asarray([[self.theta_plus * 1.]]))
203+
204+
file_name = directory + "/" + self.name + ".npz"
205+
jnp.savez(file_name,
206+
threshold_theta=self.thr_theta.value,
207+
tau_m=tau_m, thr=thr, v_rest=v_rest,
208+
v_reset=v_reset, v_decay=v_decay,
209+
resist_m=resist_m, tau_theta=tau_theta,
210+
theta_plus=theta_plus,
211+
key=self.key.value)
212+
213+
def load(self, directory, seeded=False, **kwargs):
214+
file_name = directory + "/" + self.name + ".npz"
215+
data = jnp.load(file_name)
216+
self.thr_theta.set(data['threshold_theta'])
217+
## constants loaded in
218+
self.tau_m = data['tau_m']
219+
self.thr = data['thr']
220+
self.v_rest = data['v_rest']
221+
self.v_reset = data['v_reset']
222+
self.v_decay = data['v_decay']
223+
self.resist_m = data['resist_m']
224+
self.tau_theta = data['tau_theta']
225+
self.theta_plus = data['theta_plus']
226+
227+
if seeded:
228+
self.key.set(data['key'])
247229

248230
@classmethod
249231
def help(cls): ## component help function
250232
properties = {
251-
"cell_type": "QuadLIFCell - evolves neurons according to quadratic "
252-
"leaky integrate-and-fire spiking dynamics."
233+
"cell_type": "LIFCell - evolves neurons according to leaky integrate-"
234+
"and-fire spiking dynamics."
253235
}
254236
compartment_props = {
255237
"inputs":
@@ -258,6 +240,7 @@ def help(cls): ## component help function
258240
{"v": "Membrane potential/voltage at time t",
259241
"rfr": "Current state of (relative) refractory variable",
260242
"thr": "Current state of voltage threshold at time t",
243+
"thr_theta": "Current state of homeostatic adaptive threshold at time t",
261244
"key": "JAX PRNG key"},
262245
"outputs":
263246
{"s": "Emitted spikes/pulses at time t",
@@ -271,14 +254,18 @@ def help(cls): ## component help function
271254
"v_rest": "Resting membrane potential value",
272255
"v_reset": "Reset membrane potential value",
273256
"v_decay": "Voltage leak/decay factor",
274-
"v_scale": "Scaling factor for voltage accumulation",
275-
"critical_V": "Critical voltage value",
276257
"tau_theta": "Threshold/homoestatic increment time constant",
277-
"theta_plus": "Amount to increment threshold by upon occurrence of spike",
258+
"theta_plus": "Amount to increment threshold by upon occurrence "
259+
"of spike",
278260
"refract_time": "Length of relative refractory period (ms)",
279-
"thr_jitter": "Scale of random uniform noise to apply to initial condition of threshold",
280-
"one_spike": "Should only one spike be sampled/allowed to emit at any given time step?",
281-
"integration_type": "Type of numerical integration to use for the cell dynamics"
261+
"one_spike": "Should only one spike be sampled/allowed to emit at "
262+
"any given time step?",
263+
"integration_type": "Type of numerical integration to use for the "
264+
"cell dynamics",
265+
"surrgoate_type": "Type of surrogate function to use approximate "
266+
"derivative of spike w.r.t. voltage/current",
267+
"lower_bound_clamp": "Should voltage be lower bounded to be never "
268+
"be below `v_rest`"
282269
}
283270
info = {cls.__name__: properties,
284271
"compartments": compartment_props,
@@ -301,8 +288,7 @@ def __repr__(self):
301288
return lines
302289

303290
if __name__ == '__main__':
304-
# NOTE: VN: currently error in init function
305291
from ngcsimlib.context import Context
306292
with Context("Bar") as bar:
307-
X = QuadLIFCell("X", 1, 10.)
293+
X = QuadLIFCell("X", 9, 0.0004, 3)
308294
print(X)

0 commit comments

Comments
 (0)