Skip to content

Commit 3c1ee98

Browse files
author
Alexander Ororbia
committed
revised if-cell w/ unit-test
1 parent f4e661c commit 3c1ee98

File tree

1 file changed

+15
-18
lines changed

1 file changed

+15
-18
lines changed

ngclearn/components/neurons/spiking/IFCell.py

Lines changed: 15 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -28,22 +28,6 @@ def _dfv(t, v, params): ## voltage dynamics wrapper
2828
dv_dt = _dfv_internal(j, v, rfr, tau_m, refract_T)
2929
return dv_dt
3030

31-
def _run_cell(dt, j, v, v_thr, rfr, tau_m, v_rest, v_reset, refract_T, integType=0):
32-
### Runs integrator (or integrate-and-fire; IF) neuronal dynamics
33-
## update voltage / membrane potential
34-
v_params = (j, rfr, tau_m, refract_T)
35-
if integType == 1:
36-
_, _v = step_rk2(0., v, _dfv, dt, v_params)
37-
else:
38-
_, _v = step_euler(0., v, _dfv, dt, v_params)
39-
## obtain action potentials/spikes
40-
s = (_v > v_thr).astype(jnp.float32)
41-
## update refractory variables
42-
_rfr = (rfr + dt) * (1. - s)
43-
## perform hyper-polarization of neuronal cells
44-
_v = _v * (1. - s) + s * v_reset
45-
return _v, s, _rfr
46-
4731
class IFCell(JaxComponent): ## integrate-and-fire cell
4832
"""
4933
A spiking cell based on integrate-and-fire (IF) neuronal dynamics.
@@ -162,8 +146,21 @@ def advance_state(
162146
):
163147
## run one integration step for neuronal dynamics
164148
j = j * resist_m
165-
v, s, rfr = _run_cell(dt, j, v, thr, rfr, tau_m, v_rest, v_reset,
166-
refract_T, intgFlag)
149+
150+
### Runs integrator (or integrate-and-fire; IF) neuronal dynamics
151+
## update voltage / membrane potential
152+
v_params = (j, rfr, tau_m, refract_T)
153+
if intgFlag == 1:
154+
_, _v = step_rk2(0., v, _dfv, dt, v_params)
155+
else:
156+
_, _v = step_euler(0., v, _dfv, dt, v_params)
157+
## obtain action potentials/spikes
158+
s = (_v > thr).astype(jnp.float32)
159+
## update refractory variables
160+
rfr = (rfr + dt) * (1. - s)
161+
## perform hyper-polarization of neuronal cells
162+
v = _v * (1. - s) + s * v_reset
163+
167164
surrogate = d_spike_fx(v, thr)
168165
## update tols
169166
tols = (1. - s) * tols + (s * t)

0 commit comments

Comments
 (0)