1414#from ngcsimlib.component import Component
1515from ngcsimlib .compartment import Compartment
1616
17- #@jit
18- def _dfv_internal (j , v , rfr , tau_m , refract_T , v_rest , v_decay = 1. ): ## raw voltage dynamics
19- mask = (rfr >= refract_T ) * 1. # get refractory mask
20- ## update voltage / membrane potential
21- dv_dt = (v_rest - v ) * v_decay + (j * mask )
22- dv_dt = dv_dt * (1. / tau_m )
23- return dv_dt
17+ # def _dfv_internal(j, v, rfr, tau_m, refract_T, v_rest, v_decay=1.): ## raw voltage dynamics
18+ # mask = (rfr >= refract_T) * 1. # get refractory mask
19+ # ## update voltage / membrane potential
20+ # dv_dt = (v_rest - v) * v_decay + (j * mask)
21+ # dv_dt = dv_dt * (1./tau_m)
22+ # return dv_dt
23+ #
24+ # def _dfv(t, v, params): ## voltage dynamics wrapper
25+ # j, rfr, tau_m, refract_T, v_rest, v_decay = params
26+ # dv_dt = _dfv_internal(j, v, rfr, tau_m, refract_T, v_rest, v_decay)
27+ # return dv_dt
28+
29+
2430
2531def _dfv (t , v , params ): ## voltage dynamics wrapper
26- j , rfr , tau_m , refract_T , v_rest , v_decay = params
27- dv_dt = _dfv_internal (j , v , rfr , tau_m , refract_T , v_rest , v_decay )
32+ j , rfr , tau_m , refract_T , v_rest , g_L = params
33+ mask = (rfr >= refract_T ) * 1. # get refractory mask
34+ ## update voltage / membrane potential
35+ dv_dt = (v_rest - v ) * g_L + (j * mask )
36+ dv_dt = dv_dt * (1. / tau_m )
2837 return dv_dt
2938
39+
3040#@partial(jit, static_argnums=[3, 4])
3141def _update_theta (dt , v_theta , s , tau_theta , theta_plus = 0.05 ):
3242 ### Runs homeostatic threshold update dynamics one step (via Euler integration).
@@ -38,6 +48,7 @@ def _update_theta(dt, v_theta, s, tau_theta, theta_plus=0.05):
3848 #_V_theta = V_theta + -V_theta * (dt/tau_theta) + S * alpha
3949 return _v_theta
4050
51+
4152class LIFCell (JaxComponent ): ## leaky integrate-and-fire cell
4253 """
4354 A spiking cell based on leaky integrate-and-fire (LIF) neuronal dynamics.
@@ -73,14 +84,14 @@ class LIFCell(JaxComponent): ## leaky integrate-and-fire cell
7384 thr: base value for adaptive thresholds that govern short-term
7485 plasticity (in milliVolts, or mV; default: -52. mV)
7586
76- v_rest: membrane resting potential (in mV; default: -65 mV)
87+ v_rest: reversal potential or membrane resting potential (in mV; default: -65 mV)
7788
7889 v_reset: membrane reset potential (in mV) -- upon occurrence of a spike,
7990 a neuronal cell's membrane potential will be set to this value;
8091 (default: -60 mV)
8192
82- v_decay: decay factor applied to voltage leak (Default: 1.); setting this
83- to 0 mV recovers pure integrate-and-fire (IF) dynamics
93+ conduct_leak: leak conductance (g_L) value or decay factor applied to voltage leak
94+ (Default: 1.); setting this to 0 mV recovers pure integrate-and-fire (IF) dynamics
8495
8596 tau_theta: homeostatic threshold time constant
8697
@@ -116,12 +127,12 @@ class LIFCell(JaxComponent): ## leaky integrate-and-fire cell
116127 the value of `v_rest` (default: True)
117128 """ ## batch_size arg?
118129
119- @deprecate_args (thr_jitter = None )
120- def __init__ (self , name , n_units , tau_m , resist_m = 1. , thr = - 52. , v_rest = - 65. ,
121- v_reset = - 60. , v_decay = 1. , tau_theta = 1e7 , theta_plus = 0.05 ,
122- refract_time = 5. , one_spike = False , integration_type = "euler" ,
123- surrogate_type = "straight_through" , lower_clamp_voltage = True ,
124- ** kwargs ):
130+ @deprecate_args (thr_jitter = None , v_decay = "conduct_leak" )
131+ def __init__ (
132+ self , name , n_units , tau_m , resist_m = 1. , thr = - 52. , v_rest = - 65. , v_reset = - 60. , conduct_leak = 1. , tau_theta = 1e7 ,
133+ theta_plus = 0.05 , refract_time = 5. , one_spike = False , integration_type = "euler" , surrogate_type = "straight_through " ,
134+ lower_clamp_voltage = True , ** kwargs
135+ ):
125136 super ().__init__ (name , ** kwargs )
126137
127138 ## Integration properties
@@ -136,7 +147,7 @@ def __init__(self, name, n_units, tau_m, resist_m=1., thr=-52., v_rest=-65.,
136147
137148 self .v_rest = v_rest #-65. # mV
138149 self .v_reset = v_reset # -60. # -65. # mV (milli-volts)
139- self .v_decay = v_decay ## controls strength of voltage leak (1 -> LIF, 0 => IF)
150+ self .g_L = conduct_leak ## controls strength of voltage leak (1 -> LIF, 0 => IF)
140151 ## basic asserts to prevent neuronal dynamics breaking...
141152 #assert (self.v_decay * self.dt / self.tau_m) <= 1. ## <-- to integrate in verify...
142153 assert self .resist_m > 0.
@@ -178,7 +189,7 @@ def __init__(self, name, n_units, tau_m, resist_m=1., thr=-52., v_rest=-65.,
178189 @transition (output_compartments = ["v" , "s" , "s_raw" , "rfr" , "thr_theta" , "tols" , "key" , "surrogate" ])
179190 @staticmethod
180191 def advance_state (
181- t , dt , tau_m , resist_m , v_rest , v_reset , v_decay , refract_T , thr , tau_theta , theta_plus ,
192+ t , dt , tau_m , resist_m , v_rest , v_reset , g_L , refract_T , thr , tau_theta , theta_plus ,
182193 one_spike , lower_clamp_voltage , intgFlag , d_spike_fx , key , j , v , rfr , thr_theta , tols
183194 ):
184195 skey = None ## this is an empty dkey if single_spike mode turned off
@@ -191,7 +202,7 @@ def advance_state(
191202 _v_thr = thr_theta + thr ## calc present voltage threshold
192203 #mask = (rfr >= refract_T).astype(jnp.float32) # get refractory mask
193204 ## update voltage / membrane potential
194- v_params = (j , rfr , tau_m , refract_T , v_rest , v_decay )
205+ v_params = (j , rfr , tau_m , refract_T , v_rest , g_L )
195206 if intgFlag == 1 :
196207 _ , _v = step_rk2 (0. , v , _dfv , dt , v_params )
197208 else :
0 commit comments