@@ -123,15 +123,14 @@ class LIFCell(JaxComponent): ## leaky integrate-and-fire cell
123123 "arctan" (arc-tangent estimator), and "secant_lif" (the
124124 LIF-specialized secant estimator)
125125
126- lower_clamp_voltage: if True, this will ensure voltage never is below
127- the value of `v_rest` (default: True)
126+ v_min: minimum voltage to clamp dynamics to (Default: None)
128127 """ ## batch_size arg?
129128
130129 @deprecate_args (thr_jitter = None , v_decay = "conduct_leak" )
131130 def __init__ (
132131 self , name , n_units , tau_m , resist_m = 1. , thr = - 52. , v_rest = - 65. , v_reset = - 60. , conduct_leak = 1. , tau_theta = 1e7 ,
133132 theta_plus = 0.05 , refract_time = 5. , one_spike = False , integration_type = "euler" , surrogate_type = "straight_through" ,
134- lower_clamp_voltage = True , ** kwargs
133+ v_min = None , max_one_spike = False , ** kwargs
135134 ):
136135 super ().__init__ (name , ** kwargs )
137136
@@ -143,7 +142,8 @@ def __init__(
143142 self .tau_m = tau_m ## membrane time constant
144143 self .resist_m = resist_m ## resistance value
145144 self .one_spike = one_spike ## True => constrains system to simulate 1 spike per time step
146- self .lower_clamp_voltage = lower_clamp_voltage ## True ==> ensures voltage is never < v_rest
145+ self .max_one_spike = max_one_spike
146+ self .v_min = v_min ## ensures voltage is never < v_min
147147
148148 self .v_rest = v_rest #-65. # mV
149149 self .v_reset = v_reset # -60. # -65. # mV (milli-volts)
@@ -189,11 +189,11 @@ def __init__(
189189 @transition (output_compartments = ["v" , "s" , "s_raw" , "rfr" , "thr_theta" , "tols" , "key" , "surrogate" ])
190190 @staticmethod
191191 def advance_state (
192- t , dt , tau_m , resist_m , v_rest , v_reset , g_L , refract_T , thr , tau_theta , theta_plus ,
193- one_spike , lower_clamp_voltage , intgFlag , d_spike_fx , key , j , v , rfr , thr_theta , tols
192+ t , dt , tau_m , resist_m , v_rest , v_reset , g_L , refract_T , thr , tau_theta , theta_plus , one_spike , max_one_spike ,
193+ v_min , intgFlag , d_spike_fx , key , j , v , rfr , thr_theta , tols
194194 ):
195195 skey = None ## this is an empty dkey if single_spike mode turned off
196- if one_spike :
196+ if one_spike and not max_one_spike :
197197 key , skey = random .split (key , 2 )
198198 ## run one integration step for neuronal dynamics
199199 j = j * resist_m
@@ -209,6 +209,7 @@ def advance_state(
209209 _ , _v = step_euler (0. , v , _dfv , dt , v_params )
210210 ## obtain action potentials/spikes/pulses
211211 s = (_v > _v_thr ) * 1.
212+ v_prespike = v
212213 ## update refractory variables
213214 _rfr = (rfr + dt ) * (1. - s )
214215 ## perform hyper-polarization of neuronal cells
@@ -223,6 +224,9 @@ def advance_state(
223224 rS = nn .one_hot (jnp .argmax (rS , axis = 1 ), num_classes = s .shape [1 ],
224225 dtype = jnp .float32 )
225226 s = s * (1. - m_switch ) + rS * m_switch
227+ if max_one_spike :
228+ rS = nn .one_hot (jnp .argmax (v_prespike , axis = 1 ), num_classes = s .shape [1 ], dtype = jnp .float32 ) ## get max-volt spike
229+ s = s * rS ## mask out non-max volt spikes
226230 ############################################################################
227231 raw_spikes = raw_s
228232 v = _v
@@ -234,8 +238,8 @@ def advance_state(
234238 thr_theta = _update_theta (dt , thr_theta , raw_spikes , tau_theta , theta_plus )
235239 ## update tols
236240 tols = (1. - s ) * tols + (s * t )
237- if lower_clamp_voltage : ## ensure voltage never < v_rest
238- v = jnp .maximum (v , v_rest )
241+ if v_min is not None : ## ensures voltage never < v_rest
242+ v = jnp .maximum (v , v_min )
239243 return v , s , raw_spikes , rfr , thr_theta , tols , key , surrogate
240244
241245 @transition (output_compartments = ["j" , "v" , "s" , "s_raw" , "rfr" , "tols" , "surrogate" ])
0 commit comments