Skip to content

Commit 4013bc0

Browse files
author
Alexander Ororbia
committed
refactored adex/test passed; minor cleanup in lif, raf, and wtas cells
1 parent ba3fb6d commit 4013bc0

File tree

5 files changed

+62
-79
lines changed

5 files changed

+62
-79
lines changed

ngclearn/components/neurons/spiking/LIFCell.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -200,7 +200,7 @@ def advance_state(self, dt, t):
200200
thr_theta = _update_theta(dt, self.thr_theta.get(), raw_s, self.tau_theta, self.theta_plus) #.get())
201201
self.thr_theta.set(thr_theta)
202202

203-
## update tols
203+
## update time-of-last spike variable(s)
204204
self.tols.set((1. - s) * self.tols.get() + (s * t))
205205

206206
if self.v_min is not None: ## ensures voltage never < v_rest

ngclearn/components/neurons/spiking/RAFCell.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -139,9 +139,7 @@ def __init__(
139139
) ## time-of-last-spike
140140

141141
@compilable
142-
def advance_state(
143-
self, t, dt
144-
):
142+
def advance_state(self, t, dt):
145143
## continue with centered dynamics
146144
j_ = self.j.get() * self.resist_v
147145
if self.intgFlag == 1: ## RK-2/midpoint

ngclearn/components/neurons/spiking/WTASCell.py

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -87,12 +87,8 @@ def __init__(
8787
self.rfr = Compartment(restVals + self.refract_T)
8888
self.tols = Compartment(restVals) ## time-of-last-spike
8989

90-
# @transition(output_compartments=["v", "s", "thr", "rfr", "tols"])
91-
# @staticmethod
9290
@compilable
93-
def advance_state(
94-
self, t, dt #, tau_m, R_m, thr_gain, refract_T, j, v, thr, rfr, tols
95-
):
91+
def advance_state(self, t, dt):
9692
mask = (self.rfr.get() >= self.refract_T) * 1. ## check refractory period
9793
v = (self.j.get() * self.R_m) * mask
9894
vp = softmax(v) # convert to Categorical (spike) probabilities
@@ -111,8 +107,6 @@ def advance_state(
111107
self.thr.set(thr)
112108
self.rfr.set(rfr)
113109

114-
# @transition(output_compartments=["j", "v", "s", "rfr", "tols"])
115-
# @staticmethod
116110
@compilable
117111
def reset(self):
118112
restVals = jnp.zeros((self.batch_size, self.n_units))

ngclearn/components/neurons/spiking/adExCell.py

Lines changed: 36 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,12 @@
22
from jax import numpy as jnp, random, jit, nn
33
from functools import partial
44
from ngclearn.utils import tensorstats
5-
from ngcsimlib.deprecators import deprecate_args
5+
from ngcsimlib import deprecate_args
66
from ngcsimlib.logger import info, warn
77
from ngclearn.utils.diffeq.ode_utils import get_integrator_code, \
88
step_euler, step_rk2
9-
from ngcsimlib.compilers.process import transition
10-
#from ngcsimlib.component import Component
9+
10+
from ngcsimlib.parser import compilable
1111
from ngcsimlib.compartment import Compartment
1212

1313
@jit
@@ -97,7 +97,7 @@ class AdExCell(JaxComponent):
9797
at an increase in computational cost (and simulation time)
9898
"""
9999

100-
@deprecate_args(v_thr="thr")
100+
#@deprecate_args(v_thr="thr")
101101
def __init__(
102102
self, name, n_units, tau_m=15., resist_m=1., tau_w=400., v_sharpness=2., intrinsic_mem_thr=-55., thr=5.,
103103
v_rest=-72., v_reset=-75., a=0.1, b=0.75, v0=-70., w0=0., integration_type="euler", batch_size=1, **kwargs
@@ -136,39 +136,40 @@ def __init__(
136136
self.tols = Compartment(restVals, display_name="Time-of-Last-Spike",
137137
units="ms") ## time-of-last-spike
138138

139-
@transition(output_compartments=["j", "v", "w", "s", "tols"])
140-
@staticmethod
141-
def advance_state(
142-
t, dt, tau_m, R_m, tau_w, thr, a, b, sharpV, vT, v_rest, v_reset, intgFlag, j, v, w, tols
143-
):
144-
if intgFlag == 1: ## RK-2/midpoint
145-
v_params = (j, w, tau_m, v_rest, sharpV, vT, R_m)
146-
_, _v = step_rk2(0., v, _dfv, dt, v_params)
147-
w_params = (j, v, a, tau_w, v_rest)
148-
_, _w = step_rk2(0., w, _dfw, dt, w_params)
139+
@compilable
140+
def advance_state(self, t, dt):
141+
if self.intgFlag == 1: ## RK-2/midpoint
142+
v_params = (self.j.get(), self.w.get(), self.tau_m, self.v_rest, self.sharpV, self.vT, self.R_m)
143+
_, _v = step_rk2(0., self.v.get(), _dfv, dt, v_params)
144+
w_params = (self.j.get(), self.v.get(), self.a, self.tau_w, self.v_rest)
145+
_, _w = step_rk2(0., self.w.get(), _dfw, dt, w_params)
149146
else: # intgFlag == 0 (default -- Euler)
150-
v_params = (j, w, tau_m, v_rest, sharpV, vT, R_m)
151-
_, _v = step_euler(0., v, _dfv, dt, v_params)
152-
w_params = (j, v, a, tau_w, v_rest)
153-
_, _w = step_euler(0., w, _dfw, dt, w_params)
154-
s = (_v > thr) * 1. ## emit spikes/pulses
147+
v_params = (self.j.get(), self.w.get(), self.tau_m, self.v_rest, self.sharpV, self.vT, self.R_m)
148+
_, _v = step_euler(0., self.v.get(), _dfv, dt, v_params)
149+
w_params = (self.j.get(), self.v.get(), self.a, self.tau_w, self.v_rest)
150+
_, _w = step_euler(0., self.w.get(), _dfw, dt, w_params)
151+
s = (_v > self.thr) * 1. ## emit spikes/pulses
155152
## hyperpolarize/reset/snap variables
156-
v = _v * (1. - s) + s * v_reset
157-
w = _w * (1. - s) + s * (_w + b)
158-
159-
tols = (1. - s) * tols + (s * t) ## update time-of-last spike variable(s)
160-
return j, v, w, s, tols
161-
162-
@transition(output_compartments=["j", "v", "w", "s", "tols"])
163-
@staticmethod
164-
def reset(batch_size, n_units, v0, w0):
165-
restVals = jnp.zeros((batch_size, n_units))
166-
j = restVals # None
167-
v = restVals + v0
168-
w = restVals + w0
169-
s = restVals #+ 0
170-
tols = restVals #+ 0
171-
return j, v, w, s, tols
153+
v = _v * (1. - s) + s * self.v_reset
154+
w = _w * (1. - s) + s * (_w + self.b)
155+
156+
## update time-of-last spike variable(s)
157+
self.tols.set((1. - s) * self.tols.get() + (s * t))
158+
159+
#self.j.set(j) ## j is not getting modified in these dynamics
160+
self.v.set(v)
161+
self.w.set(w)
162+
self.s.set(s)
163+
164+
@compilable
165+
def reset(self):
166+
restVals = jnp.zeros((self.batch_size, self.n_units))
167+
if not self.j.targeted:
168+
self.j.set(restVals)
169+
self.v.set(restVals + self.v0)
170+
self.w.set(restVals + self.w0)
171+
self.s.set(restVals)
172+
self.tols.set(restVals)
172173

173174
@classmethod
174175
def help(cls): ## component help function
Lines changed: 23 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,11 @@
11
from jax import numpy as jnp, random, jit
22
from ngcsimlib.context import Context
33
import numpy as np
4-
54
np.random.seed(42)
6-
from ngclearn.components import AdExCell
7-
from ngcsimlib.compilers import compile_command, wrap_command
8-
from numpy.testing import assert_array_equal
95

10-
from ngcsimlib.compilers.process import Process, transition
11-
from ngcsimlib.component import Component
12-
from ngcsimlib.compartment import Compartment
13-
from ngcsimlib.context import Context
14-
from ngcsimlib.utils.compartment import Get_Compartment_Batch
6+
from ngclearn import Context, MethodProcess
7+
from ngclearn.components.neurons.spiking.adExCell import AdExCell
8+
from numpy.testing import assert_array_equal
159

1610

1711
def test_adExCell1():
@@ -26,45 +20,41 @@ def test_adExCell1():
2620
name="a", n_units=1, tau_m=50., resist_m=30., thr=-66., key=subkeys[0]
2721
)
2822

29-
#"""
30-
advance_process = (Process("advance_proc")
23+
# """
24+
advance_process = (MethodProcess("advance_proc")
3125
>> a.advance_state)
32-
# ctx.wrap_and_add_command(advance_process.pure, name="run")
33-
ctx.wrap_and_add_command(jit(advance_process.pure), name="run")
26+
# ctx.wrap_and_add_command(jit(advance_process.pure), name="run")
3427

35-
reset_process = (Process("reset_proc")
28+
reset_process = (MethodProcess("reset_proc")
3629
>> a.reset)
37-
ctx.wrap_and_add_command(jit(reset_process.pure), name="reset")
38-
#"""
39-
40-
"""
41-
reset_cmd, reset_args = ctx.compile_by_key(a, compile_key="reset")
42-
ctx.add_command(wrap_command(jit(ctx.reset)), name="reset")
43-
advance_cmd, advance_args = ctx.compile_by_key(a, compile_key="advance_state")
44-
ctx.add_command(wrap_command(jit(ctx.advance_state)), name="run")
45-
"""
46-
30+
# ctx.wrap_and_add_command(jit(reset_process.pure), name="reset")
31+
# """
4732
## set up non-compiled utility commands
48-
@Context.dynamicCommand
49-
def clamp(x):
50-
a.j.set(x)
33+
# @Context.dynamicCommand
34+
# def clamp(x):
35+
# a.j.set(x)
36+
37+
def clamp(x):
38+
a.j.set(x)
5139

5240
## input spike train
5341
x_seq = jnp.ones((1, 10))
5442
## desired output/epsp pulses
5543
y_seq = jnp.asarray([[0., 0., 0., 0., 0., 0., 0., 1., 0., 0.]], dtype=jnp.float32)
5644

5745
outs = []
58-
ctx.reset()
46+
reset_process.run() # ctx.reset()
5947
for ts in range(x_seq.shape[1]):
6048
x_t = jnp.array([[x_seq[0, ts]]]) ## get data at time t
61-
ctx.clamp(x_t)
62-
ctx.run(t=ts * 1., dt=dt)
63-
outs.append(a.s.value)
49+
clamp(x_t) # ctx.clamp(x_t)
50+
advance_process.run(t=ts * 1., dt=dt) # ctx.run(t=ts * 1., dt=dt)
51+
outs.append(a.s.get())
52+
6453
outs = jnp.concatenate(outs, axis=1)
65-
#print(outs)
54+
# print(outs)
55+
# print(y_seq)
6656

6757
## output should equal input
6858
assert_array_equal(outs, y_seq)
6959

70-
#test_adExCell1()
60+
test_adExCell1()

0 commit comments

Comments
 (0)