Skip to content

Commit b077ee0

Browse files
author
Alexander Ororbia
committed
revised slif-cell w/ unit-test; cleaned up ode_utils to play nicer w/ new sim-lib
1 parent 6af9fd4 commit b077ee0

File tree

3 files changed

+40
-170
lines changed

3 files changed

+40
-170
lines changed

ngclearn/components/neurons/spiking/sLIFCell.py

Lines changed: 15 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -19,16 +19,12 @@ def _dfv_internal(j, v, rfr, tau_m, refract_T): ## raw voltage dynamics
1919
dv_dt = dv_dt * (1./tau_m) * mask
2020
return dv_dt
2121

22+
#@partial(jit, static_argnums=[2])
2223
def _dfv(t, v, params): ## voltage dynamics wrapper
2324
j, rfr, tau_m, refract_T = params
2425
dv_dt = _dfv_internal(j, v, rfr, tau_m, refract_T)
2526
return dv_dt
2627

27-
@jit
28-
def _hyperpolarize(v, s):
29-
_v = (1. - s) * v ## hyper-polarize cells
30-
return _v
31-
3228
@partial(jit, static_argnums=[3,4,5])
3329
def _update_threshold(dt, v_thr, spikes, thrGain=0.002, thrLeak=0.0005, rho_b = 0.):
3430
## update thresholds if applicable
@@ -51,55 +47,6 @@ def _update_refract_and_spikes(dt, rfr, s, refract_T, sticky_spikes=False):
5147
_s = s * mask + (1. - mask)
5248
return _rfr, _s
5349

54-
@partial(jit, static_argnums=[6, 7, 8, 9, 10, 11])
55-
def _run_cell(dt, j, v, v_thr, tau_m, rfr, spike_fx, refract_T=1., thrGain=0.002,
56-
thrLeak=0.0005, rho_b = 0., sticky_spikes=False, v_min=None):
57-
"""
58-
Runs leaky integrator neuronal dynamics
59-
60-
Args:
61-
dt: integration time constant (milliseconds, or ms)
62-
63-
j: electrical current value
64-
65-
v: membrane potential (voltage) value (at t)
66-
67-
v_thr: voltage threshold value (at t)
68-
69-
tau_m: cell membrane time constant
70-
71-
rfr: refractory variable vector (one per neuronal cell)
72-
73-
spike_fx: spike emission function of form `spike_fx(v, v_thr)`
74-
75-
refract_T: (relative) refractory time period (in ms; Default
76-
value is 1 ms)
77-
78-
thrGain: the amount of threshold incremented per time step (if spike present)
79-
80-
thrLeak: the amount of threshold value leaked per time step
81-
82-
rho_b: sparsity factor; if > 0, will force adaptive threshold to operate
83-
with sparsity across a layer enforced
84-
85-
sticky_spikes: if True, then spikes are pinned at value of action potential
86-
(i.e., 1) for as long as the relative refractory occurs (this recovers
87-
the source paper's core spiking process)
88-
89-
Returns:
90-
voltage(t+dt), spikes, threshold(t+dt), updated refactory variables
91-
"""
92-
#new_voltage, mask = _update_voltage(dt, j, v, rfr, tau_m, refract_T, v_min)
93-
v_params = (j, rfr, tau_m, refract_T)
94-
_, _v = step_euler(0., v, _dfv, dt, v_params) #_v = step_euler(v, v_params, _dfv, dt)
95-
# if v_min is not None:
96-
# _v = jnp.maximum(v_min, _v)
97-
spikes = spike_fx(_v, v_thr)
98-
_v = _hyperpolarize(_v, spikes)
99-
new_thr = _update_threshold(dt, v_thr, spikes, thrGain, thrLeak, rho_b)
100-
_rfr, spikes = _update_refract_and_spikes(dt, rfr, spikes, refract_T, sticky_spikes)
101-
return _v, spikes, new_thr, _rfr
102-
10350
class SLIFCell(JaxComponent): ## leaky integrate-and-fire cell
10451
"""
10552
A spiking cell based on a simplified leaky integrate-and-fire (sLIF) model.
@@ -237,14 +184,20 @@ def advance_state(
237184
if inh_R > 0.: ## if inh_R > 0, then lateral inhibition is applied
238185
j = j - (jnp.matmul(spikes, inh_weights) * inh_R)
239186
#####################################################################################
240-
241-
surrogate = d_spike_fx(j, c1=0.82, c2=0.08)
242-
#surrogate = d_spike_fx(j_curr, c1=0.82, c2=0.08)
243-
244-
v, s, thr, rfr = \
245-
_run_cell(dt, j, v, thr, tau_m,
246-
rfr, spike_fx, refract_T, thrGain, thrLeak,
247-
rho_b, sticky_spikes=sticky_spikes, v_min=v_min)
187+
surrogate = d_spike_fx(j, c1=0.82, c2=0.08) ## calc surrogate deriv of spikes
188+
189+
## transition to: voltage(t+dt), spikes, threshold(t+dt), refractory_variables(t+dt)
190+
v_params = (j, rfr, tau_m, refract_T)
191+
_, _v = step_euler(0., v, _dfv, dt, v_params)
192+
spikes = spike_fx(_v, thr)
193+
#_v = _hyperpolarize(_v, spikes)
194+
_v = (1. - spikes) * _v ## hyper-polarize cells
195+
new_thr = _update_threshold(dt, thr, spikes, thrGain, thrLeak, rho_b)
196+
_rfr, spikes = _update_refract_and_spikes(dt, rfr, spikes, refract_T, sticky_spikes)
197+
v = _v
198+
s = spikes
199+
thr = new_thr
200+
rfr = _rfr
248201

249202
## update tols
250203
tols = (1. - s) * tols + (s * t)

ngclearn/utils/diffeq/ode_utils.py

Lines changed: 23 additions & 105 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
from jax.lax import scan as _scan
1515
import time, sys
1616

17-
def get_integrator_code(integrationType):
17+
def get_integrator_code(integrationType): ## integrator type decoding routine
1818
"""
1919
Convenience function for mapping integrator type string to ngc-learn's
2020
internal integer code value.
@@ -42,22 +42,20 @@ def get_integrator_code(integrationType):
4242
to RK-1/Euler routine".format(integrationType))
4343
return intgFlag
4444

45-
4645
@jit
47-
def _sum_combine(*args, **kwargs): ## fast co-routine for simple addition
48-
sum = 0
49-
46+
def _sum_combine(*args, **kwargs): ## fast co-routine for simple addition/summation
47+
_sum = 0
5048
for arg, val in zip(args, kwargs.values()):
51-
sum = sum + val * arg
52-
return sum
49+
_sum = _sum + val * arg
50+
return _sum
5351

5452
@jit
5553
def _step_forward(t, x, dx_dt, dt, x_scale): ## internal step co-routine
5654
_t = t + dt
5755
_x = x * x_scale + dx_dt * dt
5856
return _t, _x
5957

60-
@partial(jit, static_argnums=(2, 3, 4, 5, ))
58+
@partial(jit, static_argnums=(2, 3, 5)) #(2, 3, 4, 5)
6159
def step_euler(t, x, dfx, dt, params, x_scale=1.):
6260
"""
6361
Iteratively integrates one step forward via the Euler method, i.e., a
@@ -81,14 +79,12 @@ def step_euler(t, x, dfx, dt, params, x_scale=1.):
8179
Returns:
8280
variable values iteratively integrated/advanced to next step (`t + dt`)
8381
"""
84-
8582
carry = (t, x)
8683
next_state, *_ = _euler(carry, dfx, dt, params, x_scale=x_scale)
8784
_t, _x = next_state
88-
8985
return _t, _x
9086

91-
@partial(jit, static_argnums=(1, 2, 3, 4, ))
87+
@partial(jit, static_argnums=(1, 2, 4)) #(1, 2, 3, 4))
9288
def _euler(carry, dfx, dt, params, x_scale=1.):
9389
"""
9490
Iteratively integrates one step forward via the Euler method, i.e., a
@@ -111,17 +107,12 @@ def _euler(carry, dfx, dt, params, x_scale=1.):
111107
variable values iteratively integrated/advanced to next step (`t + dt`)
112108
"""
113109
t, x = carry
114-
115110
dx_dt = dfx(t, x, params)
116111
_t, _x = _step_forward(t, x, dx_dt, dt, x_scale)
117-
118112
new_carry = (_t, _x)
119113
return new_carry, (new_carry, carry)
120114

121-
122-
123-
124-
115+
@partial(jit, static_argnums=(2, 3, 5)) #(2, 3, 4, 5))
125116
def step_heun(t, x, dfx, dt, params, x_scale=1.):
126117
"""
127118
Iteratively integrates one step forward via Heun's method, i.e., a
@@ -155,23 +146,11 @@ def step_heun(t, x, dfx, dt, params, x_scale=1.):
155146
"""
156147

157148
carry = (t, x)
158-
159149
next_state, *_ = _heun(carry, dfx, dt, params, x_scale=x_scale)
160-
161-
#
162-
# dx_dt = dfx(t, x, params)
163-
#
164-
# _t, _x = _step_forward(t, x, dx_dt, dt, x_scale)
165-
# _dx_dt = dfx(_t, _x, params)
166-
# summed_dx_dt = _sum_combine(dx_dt, _dx_dt, weight1=1, weight2=1)
167-
168-
# _, _x = _step_forward(t, x, summed_dx_dt, dt * 0.5, x_scale)
169150
_t, _x = next_state
170-
171151
return _t, _x
172152

173-
174-
@partial(jit, static_argnums=(1, 2, 3, 4, ))
153+
@partial(jit, static_argnums=(1, 2, 4)) #(1, 2, 3, 4, ))
175154
def _heun(carry, dfx, dt, params, x_scale=1.):
176155
"""
177156
Iteratively integrates one step forward via Heun's method, i.e., a
@@ -202,19 +181,15 @@ def _heun(carry, dfx, dt, params, x_scale=1.):
202181
variable values iteratively integrated/advanced to next step (`t + dt`)
203182
"""
204183
t, x = carry
205-
206184
dx_dt = dfx(t, x, params)
207185
_t, _x = _step_forward(t, x, dx_dt, dt, x_scale)
208186
_dx_dt = dfx(_t, _x, params)
209187
summed_dx_dt = _sum_combine(dx_dt, _dx_dt, weight1=1, weight2=1)
210188
_, _x = _step_forward(t, x, summed_dx_dt, dt * 0.5, x_scale)
211-
212189
new_carry = (_t, _x)
213190
return new_carry, (new_carry, carry)
214191

215-
216-
217-
192+
@partial(jit, static_argnums=(2, 3, 5)) #(2, 3, 4, 5))
218193
def step_rk2(t, x, dfx, dt, params, x_scale=1.):
219194
"""
220195
Iteratively integrates one step forward via the midpoint method, i.e., a
@@ -244,30 +219,12 @@ def step_rk2(t, x, dfx, dt, params, x_scale=1.):
244219
Returns:
245220
variable values iteratively integrated/advanced to next step (`t + dt`)
246221
"""
247-
248222
carry = (t, x)
249223
next_state, *_ = _rk2(carry, dfx, dt, params, x_scale=x_scale)
250224
_t, _x = next_state
251-
252-
#
253-
# dx_dt = dfx(t, x, params)
254-
#
255-
# _t, _x = _step_forward(t, x, dx_dt, dt, x_scale)
256-
# _dx_dt = dfx(_t, _x, params)
257-
# summed_dx_dt = _sum_combine(dx_dt, _dx_dt, weight1=1, weight2=1)
258-
259-
# _, _x = _step_forward(t, x, summed_dx_dt, dt * 0.5, x_scale)
260-
261-
262-
# dfx_1 = dfx(t, x, params)
263-
#
264-
# t1, x1 = _step_forward(t, x, dfx_1, dt * 0.5, x_scale)
265-
# dfx_2 = dfx(t1, x1, params)
266-
# _t, _x = _step_forward(t, x, dfx_2, dt, x_scale)
267225
return _t, _x
268226

269-
270-
@partial(jit, static_argnums=(1, 2, 3, 4, ))
227+
@partial(jit, static_argnums=(1, 2, 4)) #(1, 2, 3, 4, ))
271228
def _rk2(carry, dfx, dt, params, x_scale=1.):
272229
"""
273230
Iteratively integrates one step forward via the midpoint method, i.e., a
@@ -296,22 +253,19 @@ def _rk2(carry, dfx, dt, params, x_scale=1.):
296253
variable values iteratively integrated/advanced to next step (`t + dt`)
297254
"""
298255
t, x = carry
299-
300256
f_1 = dfx(t, x, params)
301257
t1, x1 = _step_forward(t, x, f_1, dt * 0.5, x_scale)
302258
f_2 = dfx(t1, x1, params)
303259
_t, _x = _step_forward(t, x, f_2, dt, x_scale)
304-
305260
new_carry = (_t, _x)
306261
return new_carry, (new_carry, carry)
307262

308-
309-
263+
@partial(jit, static_argnums=(2, 3, 5)) #(2, 3, 4, 5))
310264
def step_rk4(t, x, dfx, dt, params, x_scale=1.):
311265
"""
312266
Iteratively integrates one step forward via the midpoint method, i.e., a
313267
fourth-order Runge-Kutta (RK-4) step.
314-
(Note: ngc-learn internally recognizes "rk4" or this routine)
268+
(Note: ngc-learn internally recognizes "rk4" for this routine)
315269
316270
| Reference:
317271
| Ascher, Uri M., and Linda R. Petzold. Computer methods for ordinary
@@ -339,25 +293,9 @@ def step_rk4(t, x, dfx, dt, params, x_scale=1.):
339293
carry = (t, x)
340294
next_state, *_ = _rk4(carry, dfx, dt, params, x_scale=x_scale)
341295
_t, _x = next_state
342-
343-
# dfx_1 = dfx(t, x, params)
344-
# t2, x2 = _step_forward(t, x, dfx_1, dt * 0.5, x_scale)
345-
#
346-
# dfx_2 = dfx(t2, x2, params)
347-
# t3, x3 = _step_forward(t, x, dfx_2, dt * 0.5, x_scale)
348-
#
349-
# dfx_3 = dfx(t3, x3, params)
350-
# t4, x4 = _step_forward(t, x, dfx_3, dt, x_scale)
351-
#
352-
# dfx_4 = dfx(t4, x4, params)
353-
#
354-
# _dx_dt = _sum_combine(dfx_1, dfx_2, dfx_3, dfx_4, w_f1=1, w_f2=2, w_f3=2, w_f4=1)
355-
# _t, _x = _step_forward(t, x, _dx_dt / 6, dt, x_scale)
356296
return _t, _x
357297

358-
359-
360-
@partial(jit, static_argnums=(1, 2, 3, 4, ))
298+
@partial(jit, static_argnums=(1, 2, 4)) #(1, 2, 3, 4, ))
361299
def _rk4(carry, dfx, dt, params, x_scale=1.):
362300
"""
363301
Iteratively integrates one step forward via the midpoint method, i.e., a
@@ -385,28 +323,22 @@ def _rk4(carry, dfx, dt, params, x_scale=1.):
385323
Returns:
386324
variable values iteratively integrated/advanced to next step (`t + dt`)
387325
"""
388-
389326
t, x = carry
390-
391-
dfx_1 = dfx(t, x, params)
327+
## carry out 4 steps of RK-4
328+
dfx_1 = dfx(t, x, params) ## k1
392329
t2, x2 = _step_forward(t, x, dfx_1, dt * 0.5, x_scale)
393-
394-
dfx_2 = dfx(t2, x2, params)
330+
dfx_2 = dfx(t2, x2, params) ## k2
395331
t3, x3 = _step_forward(t, x, dfx_2, dt * 0.5, x_scale)
396-
397-
dfx_3 = dfx(t3, x3, params)
332+
dfx_3 = dfx(t3, x3, params) ## k3
398333
t4, x4 = _step_forward(t, x, dfx_3, dt, x_scale)
399-
400-
dfx_4 = dfx(t4, x4, params)
401-
334+
dfx_4 = dfx(t4, x4, params) ## k4
335+
## produce final estimate and move forward
402336
_dx_dt = _sum_combine(dfx_1, dfx_2, dfx_3, dfx_4, w_f1=1, w_f2=2, w_f3=2, w_f4=1)
403337
_t, _x = _step_forward(t, x, _dx_dt / 6, dt, x_scale)
404-
405338
new_carry = (_t, _x)
406339
return new_carry, (new_carry, carry)
407340

408-
409-
341+
@partial(jit, static_argnums=(2, 3, 5)) #(2, 3, 4, 5))
410342
def step_ralston(t, x, dfx, dt, params, x_scale=1.):
411343
"""
412344
Iteratively integrates one step forward via Ralston's method, i.e., a
@@ -438,22 +370,12 @@ def step_ralston(t, x, dfx, dt, params, x_scale=1.):
438370
Returns:
439371
variable values iteratively integrated/advanced to next step (`t + dt`)
440372
"""
441-
442373
carry = (t, x)
443-
next_state, *_ = _rk4(carry, dfx, dt, params, x_scale=x_scale)
374+
next_state, *_ = _ralston(carry, dfx, dt, params, x_scale=x_scale)
444375
_t, _x = next_state
445-
446-
# dx_dt = dfx(t, x, params) ## k1
447-
# tm, xm = _step_forward(t, x, dx_dt, dt * 0.75, x_scale)
448-
# _dx_dt = dfx(tm, xm, params) ## k2
449-
# ## Note: new step is a weighted combination of k1 and k2
450-
# summed_dx_dt = _sum_combine(dx_dt, _dx_dt, weight1=(1./3.), weight2=(2./3.))
451-
# _t, _x = _step_forward(t, x, summed_dx_dt, dt, x_scale)
452376
return _t, _x
453377

454-
455-
456-
@partial(jit, static_argnums=(1, 2, 3, 4,))
378+
@partial(jit, static_argnums=(1, 2, 4)) #(1, 2, 3, 4,))
457379
def _ralston(carry, dfx, dt, params, x_scale=1.):
458380
"""
459381
Iteratively integrates one step forward via Ralston's method, i.e., a
@@ -485,22 +407,18 @@ def _ralston(carry, dfx, dt, params, x_scale=1.):
485407
"""
486408

487409
t, x = carry
488-
489410
dx_dt = dfx(t, x, params) ## k1
490411
tm, xm = _step_forward(t, x, dx_dt, dt * 0.75, x_scale)
491412
_dx_dt = dfx(tm, xm, params) ## k2
492413
## Note: new step is a weighted combination of k1 and k2
493414
summed_dx_dt = _sum_combine(dx_dt, _dx_dt, weight1=(1./3.), weight2=(2./3.))
494415
_t, _x = _step_forward(t, x, summed_dx_dt, dt, x_scale)
495-
496416
new_carry = (_t, _x)
497417
return new_carry, (new_carry, carry)
498418

499419

500-
501420
@partial(jit, static_argnums=(0, 3, 4, 5, 6, 7, 8))
502421
def solve_ode(method_name, t0, x0, T, dfx, dt, params=None, x_scale=1., sols_only=True):
503-
504422
if method_name =='euler':
505423
method = _euler
506424
elif method_name == 'heun':

0 commit comments

Comments
 (0)