Skip to content

Commit 26e27c4

Browse files
author
Alexander Ororbia
committed
revised raf-cell w/ unit test; fixed typos/mistakes in all spiking cells
1 parent eaafb64 commit 26e27c4

File tree

10 files changed

+131
-72
lines changed

10 files changed

+131
-72
lines changed

ngclearn/components/neurons/spiking/IFCell.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -75,10 +75,10 @@ class IFCell(JaxComponent): ## integrate-and-fire cell
7575
and "midpoint" or "rk2" (midpoint method/RK-2 integration) (Default: "euler")
7676
7777
:Note: setting the integration type to the midpoint method will
78-
increase the accuray of the estimate of the cell's evolution
78+
increase the accuracy of the estimate of the cell's evolution
7979
at an increase in computational cost (and simulation time)
8080
81-
surrgoate_type: type of surrogate function to use for approximating a
81+
surrogate_type: type of surrogate function to use for approximating a
8282
partial derivative of this cell's spikes w.r.t. its voltage/current
8383
(default: "straight_through")
8484
@@ -93,7 +93,7 @@ class IFCell(JaxComponent): ## integrate-and-fire cell
9393
@deprecate_args(thr_jitter=None)
9494
def __init__(self, name, n_units, tau_m, resist_m=1., thr=-52., v_rest=-65.,
9595
v_reset=-60., refract_time=0., integration_type="euler",
96-
surrgoate_type="straight_through", lower_clamp_voltage=True,
96+
surrogate_type="straight_through", lower_clamp_voltage=True,
9797
**kwargs):
9898
super().__init__(name, **kwargs)
9999

@@ -118,9 +118,9 @@ def __init__(self, name, n_units, tau_m, resist_m=1., thr=-52., v_rest=-65.,
118118
self.n_units = n_units
119119

120120
## set up surrogate function for spike emission
121-
if surrgoate_type == "arctan":
121+
if surrogate_type == "arctan":
122122
self.spike_fx, self.d_spike_fx = arctan_estimator()
123-
elif surrgoate_type == "triangular":
123+
elif surrogate_type == "triangular":
124124
self.spike_fx, self.d_spike_fx = triangular_estimator()
125125
else: ## default: straight_through
126126
self.spike_fx, self.d_spike_fx = straight_through_estimator()

ngclearn/components/neurons/spiking/LIFCell.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -100,10 +100,10 @@ class LIFCell(JaxComponent): ## leaky integrate-and-fire cell
100100
and "midpoint" or "rk2" (midpoint method/RK-2 integration) (Default: "euler")
101101
102102
:Note: setting the integration type to the midpoint method will
103-
increase the accuray of the estimate of the cell's evolution
103+
increase the accuracy of the estimate of the cell's evolution
104104
at an increase in computational cost (and simulation time)
105105
106-
surrgoate_type: type of surrogate function to use for approximating a
106+
surrogate_type: type of surrogate function to use for approximating a
107107
partial derivative of this cell's spikes w.r.t. its voltage/current
108108
(default: "straight_through")
109109
@@ -120,7 +120,7 @@ class LIFCell(JaxComponent): ## leaky integrate-and-fire cell
120120
def __init__(self, name, n_units, tau_m, resist_m=1., thr=-52., v_rest=-65.,
121121
v_reset=-60., v_decay=1., tau_theta=1e7, theta_plus=0.05,
122122
refract_time=5., one_spike=False, integration_type="euler",
123-
surrgoate_type="straight_through", lower_clamp_voltage=True,
123+
surrogate_type="straight_through", lower_clamp_voltage=True,
124124
**kwargs):
125125
super().__init__(name, **kwargs)
126126

@@ -150,11 +150,11 @@ def __init__(self, name, n_units, tau_m, resist_m=1., thr=-52., v_rest=-65.,
150150
self.n_units = n_units
151151

152152
## set up surrogate function for spike emission
153-
if surrgoate_type == "secant_lif":
153+
if surrogate_type == "secant_lif":
154154
self.spike_fx, self.d_spike_fx = secant_lif_estimator()
155-
elif surrgoate_type == "arctan":
155+
elif surrogate_type == "arctan":
156156
self.spike_fx, self.d_spike_fx = arctan_estimator()
157-
elif surrgoate_type == "triangular":
157+
elif surrogate_type == "triangular":
158158
self.spike_fx, self.d_spike_fx = triangular_estimator()
159159
else: ## default: straight_through
160160
self.spike_fx, self.d_spike_fx = straight_through_estimator()

ngclearn/components/neurons/spiking/RAFCell.py

Lines changed: 20 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -1,28 +1,15 @@
1-
from jax import numpy as jnp, jit
2-
from ngclearn import resolver, Component, Compartment
31
from ngclearn.components.jaxComponent import JaxComponent
2+
from jax import numpy as jnp, random, jit, nn
3+
from functools import partial
44
from ngclearn.utils import tensorstats
55
from ngcsimlib.deprecators import deprecate_args
6+
from ngcsimlib.logger import info, warn
67
from ngclearn.utils.diffeq.ode_utils import get_integrator_code, \
78
step_euler, step_rk2
89

9-
@jit
10-
def _update_times(t, s, tols):
11-
"""
12-
Updates time-of-last-spike (tols) variable.
13-
14-
Args:
15-
t: current time (a scalar/int value)
16-
17-
s: binary spike vector
18-
19-
tols: current time-of-last-spike variable
20-
21-
Returns:
22-
updated tols variable
23-
"""
24-
_tols = (1. - s) * tols + (s * t)
25-
return _tols
10+
from ngcsimlib.compilers.process import transition
11+
#from ngcsimlib.component import Component
12+
from ngcsimlib.compartment import Compartment
2613

2714
@jit
2815
def _dfv_internal(j, v, w, tau_m, omega, b): ## "voltage" dynamics
@@ -48,11 +35,6 @@ def _dfw(t, w, params): ## angular driver dynamics wrapper
4835
dv_dt = _dfw_internal(j, v, w, tau_w, omega, b)
4936
return dv_dt
5037

51-
@jit
52-
def _emit_spike(v, v_thr):
53-
s = (v > v_thr).astype(jnp.float32)
54-
return s
55-
5638
class RAFCell(JaxComponent):
5739
"""
5840
The resonate-and-fire (RAF) neuronal cell
@@ -112,14 +94,15 @@ class RAFCell(JaxComponent):
11294
and "midpoint" or "rk2" (midpoint method/RK-2 integration) (Default: "euler")
11395
11496
:Note: setting the integration type to the midpoint method will
115-
increase the accuray of the estimate of the cell's evolution
97+
increase the accuracy of the estimate of the cell's evolution
11698
at an increase in computational cost (and simulation time)
11799
"""
118100

119101
@deprecate_args(resist_m="resist_v", tau_m="tau_v")
120-
def __init__(self, name, n_units, tau_v=1., tau_w=1., thr=1., omega=10.,
121-
b=-1., v_reset=1., w_reset=0., v0=0., w0=0., resist_v=1.,
122-
integration_type="euler", batch_size=1, **kwargs):
102+
def __init__(
103+
self, name, n_units, tau_v=1., tau_w=1., thr=1., omega=10., b=-1., v_reset=0., w_reset=0., v0=0., w0=0.,
104+
resist_v=1., integration_type="euler", batch_size=1, **kwargs
105+
):
123106
#v_rest=-72., v_reset=-75., w_reset=0., thr=5., v0=-70., w0=0., tau_w=400., thr=5., omega=10., b=-1.
124107
super().__init__(name, **kwargs)
125108

@@ -150,11 +133,13 @@ def __init__(self, name, n_units, tau_v=1., tau_w=1., thr=1., omega=10.,
150133
self.v = Compartment(restVals + self.v0, display_name="Voltage", units="mV")
151134
self.w = Compartment(restVals + self.w0, display_name="Angular-Driver")
152135
self.s = Compartment(restVals, display_name="Spikes")
153-
self.tols = Compartment(restVals, display_name="Time-of-Last-Spike",
154-
units="ms") ## time-of-last-spike
136+
self.tols = Compartment(
137+
restVals, display_name="Time-of-Last-Spike", units="ms"
138+
) ## time-of-last-spike
155139

140+
@transition(output_compartments=["j", "v", "w", "s", "tols"])
156141
@staticmethod
157-
def _advance_state(t, dt, tau_v, resist_v, tau_w, thr, omega, b,
142+
def advance_state(t, dt, tau_v, resist_v, tau_w, thr, omega, b,
158143
v_reset, w_reset, intgFlag, j, v, w, tols):
159144
## continue with centered dynamics
160145
j_ = j * resist_v
@@ -170,24 +155,17 @@ def _advance_state(t, dt, tau_v, resist_v, tau_w, thr, omega, b,
170155
_, _w = step_euler(0., w, _dfw, dt, w_params)
171156
v_params = (j_, _w, tau_v, omega, b)
172157
_, _v = step_euler(0., v, _dfv, dt, v_params)
173-
s = _emit_spike(_v, thr)
158+
s = (_v > thr) * 1. ## emit spikes/pulses
174159
## hyperpolarize/reset/snap variables
175160
w = _w * (1. - s) + s * w_reset
176161
v = _v * (1. - s) + s * v_reset
177162

178-
tols = _update_times(t, s, tols)
163+
tols = (1. - s) * tols + (s * t) ## update times-of-last-spike(s)
179164
return j, v, w, s, tols
180165

181-
@resolver(_advance_state)
182-
def advance_state(self, j, v, w, s, tols):
183-
self.j.set(j)
184-
self.w.set(w)
185-
self.v.set(v)
186-
self.s.set(s)
187-
self.tols.set(tols)
188-
166+
@transition(output_compartments=["j", "v", "w", "s", "tols"])
189167
@staticmethod
190-
def _reset(batch_size, n_units, v0, w0):
168+
def reset(batch_size, n_units, v0, w0):
191169
restVals = jnp.zeros((batch_size, n_units))
192170
j = restVals # None
193171
v = restVals + v0
@@ -196,14 +174,6 @@ def _reset(batch_size, n_units, v0, w0):
196174
tols = restVals #+ 0
197175
return j, v, w, s, tols
198176

199-
@resolver(_reset)
200-
def reset(self, j, v, w, s, tols):
201-
self.j.set(j)
202-
self.v.set(v)
203-
self.w.set(w)
204-
self.s.set(s)
205-
self.tols.set(tols)
206-
207177
@classmethod
208178
def help(cls): ## component help function
209179
properties = {

ngclearn/components/neurons/spiking/WTASCell.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -102,8 +102,10 @@ class WTASCell(JaxComponent): ## winner-take-all spiking cell
102102
"""
103103

104104
# Define Functions
105-
def __init__(self, name, n_units, tau_m, resist_m=1., thr_base=0.4, thr_gain=0.002,
106-
refract_time=0., thr_jitter=0.05, **kwargs):
105+
def __init__(
106+
self, name, n_units, tau_m, resist_m=1., thr_base=0.4, thr_gain=0.002, refract_time=0., thr_jitter=0.05,
107+
**kwargs
108+
):
107109
super().__init__(name, **kwargs)
108110

109111
## membrane parameter setup (affects ODE integration)

ngclearn/components/neurons/spiking/adExCell.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,7 @@ class AdExCell(JaxComponent):
9393
and "midpoint" or "rk2" (midpoint method/RK-2 integration) (Default: "euler")
9494
9595
:Note: setting the integration type to the midpoint method will
96-
increase the accuray of the estimate of the cell's evolution
96+
increase the accuracy of the estimate of the cell's evolution
9797
at an increase in computational cost (and simulation time)
9898
"""
9999

ngclearn/components/neurons/spiking/fitzhughNagumoCell.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -129,7 +129,7 @@ class FitzhughNagumoCell(JaxComponent):
129129
and "midpoint" or "rk2" (midpoint method/RK-2 integration) (Default: "euler")
130130
131131
:Note: setting the integration type to the midpoint method will
132-
increase the accuray of the estimate of the cell's evolution
132+
increase the accuracy of the estimate of the cell's evolution
133133
at an increase in computational cost (and simulation time)
134134
"""
135135

ngclearn/components/neurons/spiking/izhikevichCell.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -159,7 +159,7 @@ class IzhikevichCell(JaxComponent): ## Izhikevich neuronal cell
159159
and "midpoint" or "rk2" (midpoint method/RK-2 integration) (Default: "euler")
160160
161161
:Note: setting the integration type to the midpoint method will
162-
increase the accuray of the estimate of the cell's evolution
162+
increase the accuracy of the estimate of the cell's evolution
163163
at an increase in computational cost (and simulation time)
164164
"""
165165

ngclearn/components/neurons/spiking/quadLIFCell.py

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -100,17 +100,34 @@ class QuadLIFCell(LIFCell): ## quadratic integrate-and-fire cell
100100
a single spike will be permitted to emit per step -- this means that
101101
if > 1 spikes emitted, a single action potential will be randomly
102102
sampled from the non-zero spikes detected
103+
104+
integration_type: type of integration to use for this cell's dynamics;
105+
current supported forms include "euler" (Euler/RK-1 integration)
106+
and "midpoint" or "rk2" (midpoint method/RK-2 integration) (Default: "euler")
107+
108+
:Note: setting the integration type to the midpoint method will
109+
increase the accuracy of the estimate of the cell's evolution
110+
at an increase in computational cost (and simulation time)
111+
112+
surrogate_type: type of surrogate function to use for approximating a
113+
partial derivative of this cell's spikes w.r.t. its voltage/current
114+
(default: "straight_through")
115+
116+
:Note: surrogate options available include: "straight_through"
117+
(straight-through estimator), "triangular" (triangular estimator),
118+
"arctan" (arc-tangent estimator), and "secant_lif" (the
119+
LIF-specialized secant estimator)
103120
""" ## batch_size arg?
104121

105122
@deprecate_args(thr_jitter=None, critical_v="critical_V")
106123
def __init__(
107124
self, name, n_units, tau_m, resist_m=1., thr=-52., v_rest=-65., v_reset=-60., v_scale=-41.6, critical_v=1.,
108125
tau_theta=1e7, theta_plus=0.05, refract_time=5., one_spike=False, integration_type="euler",
109-
surrgoate_type="straight_through", lower_clamp_voltage=True, **kwargs
126+
surrogate_type="straight_through", lower_clamp_voltage=True, **kwargs
110127
):
111128
super().__init__(
112129
name, n_units, tau_m, resist_m, thr, v_rest, v_reset, 1., tau_theta, theta_plus, refract_time,
113-
one_spike, integration_type, surrgoate_type, lower_clamp_voltage, **kwargs
130+
one_spike, integration_type, surrogate_type, lower_clamp_voltage, **kwargs
114131
)
115132
## only two distinct additional constants distinguish the Quad-LIF cell
116133
self.v_c = v_scale

ngclearn/components/neurons/spiking/sLIFCell.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -113,10 +113,10 @@ class SLIFCell(JaxComponent): ## leaky integrate-and-fire cell
113113
"""
114114

115115
# Define Functions
116-
def __init__(self, name, n_units, tau_m, resist_m, thr, resist_inh=0.,
117-
thr_persist=False, thr_gain=0.0, thr_leak=0.0, rho_b=0.,
118-
refract_time=0., sticky_spikes=False, thr_jitter=0.05,
119-
batch_size=1, **kwargs):
116+
def __init__(
117+
self, name, n_units, tau_m, resist_m, thr, resist_inh=0., thr_persist=False, thr_gain=0.0, thr_leak=0.0,
118+
rho_b=0., refract_time=0., sticky_spikes=False, thr_jitter=0.05, batch_size=1, **kwargs
119+
):
120120
super().__init__(name, **kwargs)
121121

122122
## membrane parameter setup (affects ODE integration)
Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
from jax import numpy as jnp, random, jit
2+
from ngcsimlib.context import Context
3+
import numpy as np
4+
5+
np.random.seed(42)
6+
from ngclearn.components import RAFCell
7+
from ngcsimlib.compilers import compile_command, wrap_command
8+
from numpy.testing import assert_array_equal
9+
10+
from ngcsimlib.compilers.process import Process, transition
11+
from ngcsimlib.component import Component
12+
from ngcsimlib.compartment import Compartment
13+
from ngcsimlib.context import Context
14+
from ngcsimlib.utils.compartment import Get_Compartment_Batch
15+
16+
17+
def test_RAFCell1():
18+
name = "raf_ctx"
19+
## create seeding keys
20+
dkey = random.PRNGKey(1234)
21+
dkey, *subkeys = random.split(dkey, 6)
22+
dt = 1. # ms
23+
# ---- build a simple Poisson cell system ----
24+
with Context(name) as ctx:
25+
a = RAFCell(
26+
name="a", n_units=1, tau_v=20., resist_v=1., key=subkeys[0]
27+
)
28+
29+
#"""
30+
advance_process = (Process()
31+
>> a.advance_state)
32+
# ctx.wrap_and_add_command(advance_process.pure, name="run")
33+
ctx.wrap_and_add_command(jit(advance_process.pure), name="run")
34+
35+
reset_process = (Process()
36+
>> a.reset)
37+
ctx.wrap_and_add_command(jit(reset_process.pure), name="reset")
38+
#"""
39+
40+
"""
41+
reset_cmd, reset_args = ctx.compile_by_key(a, compile_key="reset")
42+
ctx.add_command(wrap_command(jit(ctx.reset)), name="reset")
43+
advance_cmd, advance_args = ctx.compile_by_key(a, compile_key="advance_state")
44+
ctx.add_command(wrap_command(jit(ctx.advance_state)), name="run")
45+
"""
46+
47+
## set up non-compiled utility commands
48+
@Context.dynamicCommand
49+
def clamp(x):
50+
a.j.set(x)
51+
52+
## input spike train
53+
x_seq = jnp.asarray([[0., 1., 0., 0., 0., 0., 1., 0., 0.]], dtype=jnp.float32)
54+
## desired output/epsp pulses
55+
y_seq = jnp.asarray([[0., 0., 0., 1., 0., 0., 0., 0., 1.]], dtype=jnp.float32)
56+
57+
outs = []
58+
ctx.reset()
59+
for ts in range(x_seq.shape[1]):
60+
x_t = jnp.array([[x_seq[0, ts]]]) ## get data at time t
61+
ctx.clamp(x_t)
62+
ctx.run(t=ts * 1., dt=dt)
63+
outs.append(a.s.value)
64+
outs = jnp.concatenate(outs, axis=1)
65+
#print(outs)
66+
67+
## output should equal input
68+
assert_array_equal(outs, y_seq)
69+
70+
#test_RAFCell1()

0 commit comments

Comments
 (0)