1+ from ngclearn .components .jaxComponent import JaxComponent
12from jax import numpy as jnp , random , jit
23from functools import partial
3- from ngclearn import resolver , Component , Compartment
4- from ngclearn .components .jaxComponent import JaxComponent
4+ from ngclearn .utils import tensorstats
5+ from ngcsimlib .deprecators import deprecate_args
6+ from ngcsimlib .logger import info , warn
57from ngclearn .utils .diffeq .ode_utils import step_euler
68from ngclearn .utils .surrogate_fx import secant_lif_estimator
7- from ngclearn .utils import tensorstats
8-
9- @jit
10- def _update_times (t , s , tols ):
11- """
12- Updates time-of-last-spike (tols) variable.
13-
14- Args:
15- t: current time (a scalar/int value)
16-
17- s: binary spike vector
18-
19- tols: current time-of-last-spike variable
20-
21- Returns:
22- updated tols variable
23- """
24- _tols = (1. - s ) * tols + (s * t )
25- return _tols
26-
27- @partial (jit , static_argnums = [3 ,4 ])
28- def _modify_current (j , spikes , inh_weights , R_m , inh_R ):
29- """
30- A simple function that modifies electrical current j via application of a
31- scalar membrane resistance value and an approximate form of lateral inhibition.
32- Note that if no inhibitory resistance is set (i.e., inh_R = 0), then no
33- lateral inhibition is applied. Functionally, this routine carries out the
34- following piecewise equation:
35-
36- | j * R_m - [Wi * s(t-dt)] * inh_R, if inh_R > 0
37- | j * R_m, otherwise
38-
39- Args:
40- j: electrical current value
41-
42- spikes: previous binary spike vector (for t-dt)
43-
44- inh_weights: lateral recurrent inhibitory synapses (typically should be
45- chosen to be a scaled hollow matrix)
46-
47- R_m: membrane resistance (to multiply/scale j by)
489
49- inh_R: inhibitory resistance to scale lateral inhibitory current by; if
50- inh_R = 0, NO lateral inhibitory pressure will be applied
51-
52- Returns:
53- modified electrical current value
54- """
55- _j = j * R_m
56- if inh_R > 0. :
57- _j = _j - (jnp .matmul (spikes , inh_weights ) * inh_R )
58- return _j
10+ from ngcsimlib .compilers .process import transition
11+ #from ngcsimlib.component import Component
12+ from ngcsimlib .compartment import Compartment
5913
6014@jit
6115def _dfv_internal (j , v , rfr , tau_m , refract_T ): ## raw voltage dynamics
@@ -97,6 +51,7 @@ def _update_refract_and_spikes(dt, rfr, s, refract_T, sticky_spikes=False):
9751 _s = s * mask + (1. - mask )
9852 return _rfr , _s
9953
54+ @partial (jit , static_argnums = [6 , 7 , 8 , 9 , 10 , 11 ])
10055def _run_cell (dt , j , v , v_thr , tau_m , rfr , spike_fx , refract_T = 1. , thrGain = 0.002 ,
10156 thrLeak = 0.0005 , rho_b = 0. , sticky_spikes = False , v_min = None ):
10257 """
@@ -206,6 +161,8 @@ class SLIFCell(JaxComponent): ## leaky integrate-and-fire cell
206161 a key setting used by Samadi et al., 2017
207162
208163 thr_jitter: scale of uniform jitter to add to initialization of thresholds
164+
165+ batch_size: batch size dimension of this cell (Default: 1)
209166 """
210167
211168 # Define Functions
@@ -258,36 +215,44 @@ def __init__(self, name, n_units, tau_m, resist_m, thr, resist_inh=0.,
258215 self .rfr = Compartment (restVals + self .refract_T ) ## refractory variable(s)
259216 self .surrogate = Compartment (restVals + 1. ) ## surrogate signal
260217
218+ @transition (output_compartments = ["j" , "s" , "tols" , "v" , "thr" , "rfr" , "surrogate" ])
261219 @staticmethod
262- def _advance_state (t , dt , inh_weights , R_m , inh_R , d_spike_fx , tau_m ,
263- spike_fx , refract_T , thrGain , thrLeak , rho_b ,
264- sticky_spikes , v_min , j , s , v , thr , rfr , tols ):
265- ## run one step of Euler integration over neuronal dynamics
266- j_curr = j
267- ## apply simplified inhibitory pressure
268- j_curr = _modify_current (j_curr , s , inh_weights , R_m , inh_R )
269- j = j_curr # None ## store electrical current
270- surrogate = d_spike_fx (j_curr , c1 = 0.82 , c2 = 0.08 )
220+ def advance_state (
221+ t , dt , inh_weights , R_m , inh_R , d_spike_fx , tau_m , spike_fx , refract_T , thrGain ,
222+ thrLeak , rho_b , sticky_spikes , v_min , j , s , v , thr , rfr , tols
223+ ):
224+ #####################################################################################
225+ #The following 3 lines of code modify electrical current j via application of a
226+ #scalar membrane resistance value and an approximate form of lateral inhibition.
227+ #Functionally, this routine carries out the following piecewise equation:
228+ #| j * R_m - [Wi * s(t-dt)] * inh_R, if inh_R > 0
229+ #| j * R_m, otherwise
230+ #| where j: electrical current value, spikes: previous binary spike vector (for t-dt),
231+ # inh_weights: lateral recurrent inhibitory synapses (typically should be chosen
232+ # to be a scaled hollow matrix),
233+ #| R_m: membrane resistance (to multiply/scale j by),
234+ #| inh_R: inhibitory resistance to scale lateral inhibitory current by; if inh_R = 0,
235+ # NO lateral inhibitory pressure will be applied
236+ j = j * R_m
237+ if inh_R > 0. : ## if inh_R > 0, then lateral inhibition is applied
238+ j = j - (jnp .matmul (spikes , inh_weights ) * inh_R )
239+ #####################################################################################
240+
241+ surrogate = d_spike_fx (j , c1 = 0.82 , c2 = 0.08 )
242+ #surrogate = d_spike_fx(j_curr, c1=0.82, c2=0.08)
243+
271244 v , s , thr , rfr = \
272- _run_cell (dt , j_curr , v , thr , tau_m ,
245+ _run_cell (dt , j , v , thr , tau_m ,
273246 rfr , spike_fx , refract_T , thrGain , thrLeak ,
274247 rho_b , sticky_spikes = sticky_spikes , v_min = v_min )
248+
275249 ## update tols
276- tols = _update_times ( t , s , tols )
250+ tols = ( 1. - s ) * tols + ( s * t )
277251 return j , s , tols , v , thr , rfr , surrogate
278252
279- @resolver (_advance_state )
280- def advance_state (self , j , s , tols , v , thr , rfr , surrogate ):
281- self .j .set (j )
282- self .s .set (s )
283- self .tols .set (tols )
284- self .thr .set (thr )
285- self .rfr .set (rfr )
286- self .surrogate .set (surrogate )
287- self .v .set (v )
288-
253+ @transition (output_compartments = ["j" , "s" , "tols" , "v" , "thr" , "rfr" , "surrogate" ])
289254 @staticmethod
290- def _reset (refract_T , thr_persist , threshold0 , batch_size , n_units , thr ):
255+ def reset (refract_T , thr_persist , threshold0 , batch_size , n_units , thr ):
291256 restVals = jnp .zeros ((batch_size , n_units ))
292257 voltage = restVals
293258 refract = restVals + refract_T
@@ -299,16 +264,6 @@ def _reset(refract_T, thr_persist, threshold0, batch_size, n_units, thr):
299264 thr = threshold0 + 0
300265 return current , spikes , timeOfLastSpike , voltage , thr , refract , surrogate
301266
302- @resolver (_reset )
303- def reset (self , j , s , tols , v , thr , rfr , surrogate ):
304- self .j .set (j )
305- self .s .set (s )
306- self .tols .set (tols )
307- self .thr .set (thr )
308- self .rfr .set (rfr )
309- self .surrogate .set (surrogate )
310- self .v .set (v )
311-
312267 def save (self , directory , ** kwargs ):
313268 file_name = directory + "/" + self .name + ".npz"
314269 if self .thr_persist == False :
0 commit comments