@@ -28,22 +28,6 @@ def _dfv(t, v, params): ## voltage dynamics wrapper
2828 dv_dt = _dfv_internal (j , v , rfr , tau_m , refract_T )
2929 return dv_dt
3030
31- def _run_cell (dt , j , v , v_thr , rfr , tau_m , v_rest , v_reset , refract_T , integType = 0 ):
32- ### Runs integrator (or integrate-and-fire; IF) neuronal dynamics
33- ## update voltage / membrane potential
34- v_params = (j , rfr , tau_m , refract_T )
35- if integType == 1 :
36- _ , _v = step_rk2 (0. , v , _dfv , dt , v_params )
37- else :
38- _ , _v = step_euler (0. , v , _dfv , dt , v_params )
39- ## obtain action potentials/spikes
40- s = (_v > v_thr ).astype (jnp .float32 )
41- ## update refractory variables
42- _rfr = (rfr + dt ) * (1. - s )
43- ## perform hyper-polarization of neuronal cells
44- _v = _v * (1. - s ) + s * v_reset
45- return _v , s , _rfr
46-
4731class IFCell (JaxComponent ): ## integrate-and-fire cell
4832 """
4933 A spiking cell based on integrate-and-fire (IF) neuronal dynamics.
@@ -162,8 +146,21 @@ def advance_state(
162146 ):
163147 ## run one integration step for neuronal dynamics
164148 j = j * resist_m
165- v , s , rfr = _run_cell (dt , j , v , thr , rfr , tau_m , v_rest , v_reset ,
166- refract_T , intgFlag )
149+
150+ ### Runs integrator (or integrate-and-fire; IF) neuronal dynamics
151+ ## update voltage / membrane potential
152+ v_params = (j , rfr , tau_m , refract_T )
153+ if intgFlag == 1 :
154+ _ , _v = step_rk2 (0. , v , _dfv , dt , v_params )
155+ else :
156+ _ , _v = step_euler (0. , v , _dfv , dt , v_params )
157+ ## obtain action potentials/spikes
158+ s = (_v > thr ).astype (jnp .float32 )
159+ ## update refractory variables
160+ rfr = (rfr + dt ) * (1. - s )
161+ ## perform hyper-polarization of neuronal cells
162+ v = _v * (1. - s ) + s * v_reset
163+
167164 surrogate = d_spike_fx (v , thr )
168165 ## update tols
169166 tols = (1. - s ) * tols + (s * t )
0 commit comments