Skip to content

Commit f4e661c

Browse files
author
Alexander Ororbia
committed
revised if-cell w/ unit-test
1 parent 68ccc09 commit f4e661c

File tree

2 files changed

+84
-42
lines changed

2 files changed

+84
-42
lines changed

ngclearn/components/neurons/spiking/IFCell.py

Lines changed: 15 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -1,31 +1,19 @@
1+
from ngclearn.components.jaxComponent import JaxComponent
12
from jax import numpy as jnp, random, jit, nn
3+
from functools import partial
24
from ngclearn.utils import tensorstats
35
from ngcsimlib.deprecators import deprecate_args
4-
from ngclearn import resolver, Component, Compartment
5-
from ngclearn.components.jaxComponent import JaxComponent
6+
from ngcsimlib.logger import info, warn
67
from ngclearn.utils.diffeq.ode_utils import get_integrator_code, \
78
step_euler, step_rk2
8-
from ngclearn.utils.surrogate_fx import (arctan_estimator,
9+
from ngclearn.utils.surrogate_fx import (secant_lif_estimator, arctan_estimator,
910
triangular_estimator,
1011
straight_through_estimator)
1112

12-
@jit
13-
def _update_times(t, s, tols):
14-
"""
15-
Updates time-of-last-spike (tols) variable.
16-
17-
Args:
18-
t: current time (a scalar/int value)
13+
from ngcsimlib.compilers.process import transition
14+
#from ngcsimlib.component import Component
15+
from ngcsimlib.compartment import Compartment
1916

20-
s: binary spike vector
21-
22-
tols: current time-of-last-spike variable
23-
24-
Returns:
25-
updated tols variable
26-
"""
27-
_tols = (1. - s) * tols + (s * t)
28-
return _tols
2917

3018
@jit
3119
def _dfv_internal(j, v, rfr, tau_m, refract_T): ## raw voltage dynamics
@@ -166,32 +154,26 @@ def __init__(self, name, n_units, tau_m, resist_m=1., thr=-52., v_rest=-65.,
166154
units="ms") ## time-of-last-spike
167155
self.surrogate = Compartment(restVals + 1., display_name="Surrogate State Value")
168156

157+
@transition(output_compartments=["v", "s", "rfr", "tols", "key", "surrogate"])
169158
@staticmethod
170-
def _advance_state(t, dt, tau_m, resist_m, v_rest, v_reset, refract_T,
171-
thr, lower_clamp_voltage, intgFlag, d_spike_fx, key,
172-
j, v, rfr, tols):
159+
def advance_state(
160+
t, dt, tau_m, resist_m, v_rest, v_reset, refract_T, thr, lower_clamp_voltage, intgFlag, d_spike_fx, key,
161+
j, v, rfr, tols
162+
):
173163
## run one integration step for neuronal dynamics
174164
j = j * resist_m
175165
v, s, rfr = _run_cell(dt, j, v, thr, rfr, tau_m, v_rest, v_reset,
176166
refract_T, intgFlag)
177167
surrogate = d_spike_fx(v, thr)
178168
## update tols
179-
tols = _update_times(t, s, tols)
169+
tols = (1. - s) * tols + (s * t)
180170
if lower_clamp_voltage: ## ensure voltage never < v_rest
181171
v = jnp.maximum(v, v_rest)
182172
return v, s, rfr, tols, key, surrogate
183173

184-
@resolver(_advance_state)
185-
def advance_state(self, v, s, rfr, tols, key, surrogate):
186-
self.v.set(v)
187-
self.s.set(s)
188-
self.rfr.set(rfr)
189-
self.tols.set(tols)
190-
self.key.set(key)
191-
self.surrogate.set(surrogate)
192-
174+
@transition(output_compartments=["j", "v", "s", "rfr", "tols", "surrogate"])
193175
@staticmethod
194-
def _reset(batch_size, n_units, v_rest, refract_T):
176+
def reset(batch_size, n_units, v_rest, refract_T):
195177
restVals = jnp.zeros((batch_size, n_units))
196178
j = restVals #+ 0
197179
v = restVals + v_rest
@@ -201,15 +183,6 @@ def _reset(batch_size, n_units, v_rest, refract_T):
201183
surrogate = restVals + 1.
202184
return j, v, s, rfr, tols, surrogate
203185

204-
@resolver(_reset)
205-
def reset(self, j, v, s, rfr, tols, surrogate):
206-
self.j.set(j)
207-
self.v.set(v)
208-
self.s.set(s)
209-
self.rfr.set(rfr)
210-
self.tols.set(tols)
211-
self.surrogate.set(surrogate)
212-
213186
def save(self, directory, **kwargs):
214187
## do a protected save of constants, depending on whether they are floats or arrays
215188
tau_m = (self.tau_m if isinstance(self.tau_m, float)
Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
from jax import numpy as jnp, random, jit
2+
from ngcsimlib.context import Context
3+
import numpy as np
4+
np.random.seed(42)
5+
from ngclearn.components import IFCell
6+
from ngcsimlib.compilers import compile_command, wrap_command
7+
from numpy.testing import assert_array_equal
8+
9+
from ngcsimlib.compilers.process import Process, transition
10+
from ngcsimlib.component import Component
11+
from ngcsimlib.compartment import Compartment
12+
from ngcsimlib.context import Context
13+
from ngcsimlib.utils.compartment import Get_Compartment_Batch
14+
15+
def test_IFCell1():
16+
name = "if_ctx"
17+
## create seeding keys
18+
dkey = random.PRNGKey(1234)
19+
dkey, *subkeys = random.split(dkey, 6)
20+
dt = 1. # ms
21+
trace_increment = 0.1
22+
# ---- build a simple Poisson cell system ----
23+
with Context(name) as ctx:
24+
a = IFCell(
25+
name="a", n_units=1, tau_m=5., resist_m=10., key=subkeys[0]
26+
)
27+
28+
#"""
29+
advance_process = (Process()
30+
>> a.advance_state)
31+
#ctx.wrap_and_add_command(advance_process.pure, name="run")
32+
ctx.wrap_and_add_command(jit(advance_process.pure), name="run")
33+
34+
reset_process = (Process()
35+
>> a.reset)
36+
ctx.wrap_and_add_command(jit(reset_process.pure), name="reset")
37+
#"""
38+
39+
"""
40+
reset_cmd, reset_args = ctx.compile_by_key(a, compile_key="reset")
41+
ctx.add_command(wrap_command(jit(ctx.reset)), name="reset")
42+
advance_cmd, advance_args = ctx.compile_by_key(a, compile_key="advance_state")
43+
ctx.add_command(wrap_command(jit(ctx.advance_state)), name="run")
44+
"""
45+
46+
## set up non-compiled utility commands
47+
@Context.dynamicCommand
48+
def clamp(x):
49+
a.j.set(x)
50+
51+
## input spike train
52+
x_seq = jnp.asarray([[1., 1., 1., 1., 1., 0., 0., 1., 1., 1., 1., 1., 1., 1., 0.]], dtype=jnp.float32)
53+
## desired output/epsp pulses
54+
y_seq = jnp.asarray([[0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 1., 0.]], dtype=jnp.float32)
55+
56+
outs = []
57+
ctx.reset()
58+
for ts in range(x_seq.shape[1]):
59+
x_t = jnp.array([[x_seq[0, ts]]]) ## get data at time t
60+
ctx.clamp(x_t)
61+
ctx.run(t=ts * 1., dt=dt)
62+
outs.append(a.s.value)
63+
outs = jnp.concatenate(outs, axis=1)
64+
print(outs)
65+
66+
## output should equal input
67+
assert_array_equal(outs, y_seq)
68+
69+
#test_IFCell1()

0 commit comments

Comments
 (0)