1+ from ngclearn .components .jaxComponent import JaxComponent
12from jax import numpy as jnp , random , jit , nn
23from functools import partial
3- import time , sys
44from ngclearn .utils import tensorstats
5- from ngclearn import resolver , Component , Compartment
5+ from ngcsimlib .deprecators import deprecate_args
6+ from ngcsimlib .logger import info , warn
67from ngclearn .utils .diffeq .ode_utils import get_integrator_code , \
78 step_euler , step_rk2
8- ## import parent cell class/component
9- from ngclearn .components .neurons .spiking .LIFCell import LIFCell
10-
11- @jit
12- def _update_times (t , s , tols ):
13- """
14- Updates time-of-last-spike (tols) variable.
15-
16- Args:
17- t: current time (a scalar/int value)
9+ from ngclearn .utils .surrogate_fx import (secant_lif_estimator , arctan_estimator ,
10+ triangular_estimator ,
11+ straight_through_estimator )
1812
19- s: binary spike vector
13+ from ngcsimlib .compilers .process import transition
14+ #from ngcsimlib.component import Component
15+ from ngcsimlib .compartment import Compartment
2016
21- tols: current time-of-last-spike variable
22-
23- Returns:
24- updated tols variable
25- """
26- _tols = (1. - s ) * tols + (s * t )
27- return _tols
28-
29- @jit
30- def _modify_current (j , dt , tau_m ): ## electrical current re-scaling co-routine
31- jScale = tau_m / dt
32- return j * jScale
17+ from ngclearn .components .neurons .spiking .LIFCell import LIFCell
3318
3419@jit
3520def _dfv_internal (j , v , rfr , tau_m , refract_T , v_rest , v_c , a0 ): ## raw voltage dynamics
36- mask = (rfr >= refract_T ). astype ( jnp . float32 ) # get refractory mask
21+ mask = (rfr >= refract_T ) * 1. # get refractory mask
3722 ## update voltage / membrane potential
3823 dv_dt = ((v_rest - v ) * (v - v_c ) * a0 ) + (j * mask )
3924 dv_dt = dv_dt * (1. / tau_m )
@@ -44,101 +29,18 @@ def _dfv(t, v, params): ## voltage dynamics wrapper
4429 dv_dt = _dfv_internal (j , v , rfr , tau_m , refract_T , v_rest , v_c , a0 )
4530 return dv_dt
4631
47- #@partial(jit, static_argnums=[7,8,9,10,11,12,13,14])
48- def _run_cell (dt , j , v , v_thr , v_theta , rfr , skey , v_c , a0 , tau_m , v_rest ,
49- v_reset , refract_T , integType = 0 ):
50- """
51- Runs quadratic leaky integrator neuronal dynamics
52-
53- Args:
54- dt: integration time constant (milliseconds, or ms)
55-
56- j: electrical current value
57-
58- v: membrane potential (voltage, in milliVolts or mV) value (at t)
59-
60- v_thr: base voltage threshold value (in mV)
61-
62- v_theta: threshold shift (homeostatic) variable (at t)
63-
64- rfr: refractory variable vector (one per neuronal cell)
65-
66- skey: PRNG key which, if not None, will trigger a single-spike constraint
67- (i.e., only one spike permitted to emit per single step of time);
68- specifically used to randomly sample one of the possible action
69- potentials to be an emitted spike
70-
71- v_c: scaling factor for voltage accumulation
72-
73- a0: critical voltage value
74-
75- tau_m: cell membrane time constant
76-
77- v_rest: membrane resting potential (in mV)
78-
79- v_reset: membrane reset potential (in mV) -- upon occurrence of a spike,
80- a neuronal cell's membrane potential will be set to this value
81-
82- refract_T: (relative) refractory time period (in ms; Default
83- value is 1 ms)
84-
85- integType: integer indicating type of integration to use
86-
87- Returns:
88- voltage(t+dt), spikes, raw spikes, updated refactory variables
89- """
90- _v_thr = v_theta + v_thr ## calc present voltage threshold
91- #mask = (rfr >= refract_T).astype(jnp.float32) # get refractory mask
92- ## update voltage / membrane potential (v_c ~> 0.8?) (a0 usually <1?)
93- #_v = v + ((v_rest - v) * (v - v_c) * a0) * (dt/tau_m) + (j * mask)
94- v_params = (j , rfr , tau_m , refract_T , v_rest , v_c , a0 )
95- if integType == 1 :
96- _ , _v = step_rk2 (0. , v , _dfv , dt , v_params )
97- else :
98- _ , _v = step_euler (0. , v , _dfv , dt , v_params )
99- ## obtain action potentials
100- s = (_v > _v_thr ).astype (jnp .float32 )
101- ## update refractory variables
102- _rfr = (rfr + dt ) * (1. - s )
103- ## perform hyper-polarization of neuronal cells
104- _v = _v * (1. - s ) + s * v_reset
105-
106- raw_s = s + 0 ## preserve un-altered spikes
107- ############################################################################
108- ## this is a spike post-processing step
109- if skey is not None : ## FIXME: this would not work for mini-batches!!!!!!!
110- m_switch = (jnp .sum (s ) > 0. ).astype (jnp .float32 )
111- rS = random .choice (skey , s .shape [1 ], p = jnp .squeeze (s ))
112- rS = nn .one_hot (rS , num_classes = s .shape [1 ], dtype = jnp .float32 )
113- s = s * (1. - m_switch ) + rS * m_switch
114- ############################################################################
115- return _v , s , raw_s , _rfr
116-
117- @partial (jit , static_argnums = [3 ,4 ])
32+ #@partial(jit, static_argnums=[3, 4])
11833def _update_theta (dt , v_theta , s , tau_theta , theta_plus = 0.05 ):
119- """
120- Runs homeostatic threshold update dynamics one step.
121-
122- Args:
123- dt: integration time constant (milliseconds, or ms)
124-
125- v_theta: current value of homeostatic threshold variable
126-
127- s: current spikes (at t)
128-
129- tau_theta: homeostatic threshold time constant
130-
131- theta_plus: physical increment to be applied to any threshold value if
132- a spike was emitted
133-
134- Returns:
135- updated homeostatic threshold variable
136- """
34+ ### Runs homeostatic threshold update dynamics one step (via Euler integration).
35+ #theta_decay = 0.9999999 #0.999999762 #jnp.exp(-dt/1e7)
36+ #theta_plus = 0.05
37+ #_V_theta = V_theta * theta_decay + S * theta_plus
13738 theta_decay = jnp .exp (- dt / tau_theta )
13839 _v_theta = v_theta * theta_decay + s * theta_plus
40+ #_V_theta = V_theta + -V_theta * (dt/tau_theta) + S * alpha
13941 return _v_theta
14042
141- class QuadLIFCell (LIFCell ): ## quadratic (leaky) LIF cell; inherits from LIFCell
43+ class QuadLIFCell (LIFCell ): ## quadratic integrate-and-fire cell
14244 """
14345 A spiking cell based on quadratic leaky integrate-and-fire (LIF) neuronal
14446 dynamics. Note that QuadLIFCell is a child of LIFCell and inherits its
@@ -184,9 +86,9 @@ class QuadLIFCell(LIFCell): ## quadratic (leaky) LIF cell; inherits from LIFCell
18486
18587 v_scale: scaling factor for voltage accumulation (v_c)
18688
187- critical_V : critical voltage value (a0)
89+ critical_v : critical voltage value (in mV) (i.e., variable name - a0)
18890
189- tau_theta: homeostatic threshold time constant
91+ tau_theta: homeostatic threshold time constant
19092
19193 theta_plus: physical increment to be applied to any threshold value if
19294 a spike was emitted
@@ -198,58 +100,138 @@ class QuadLIFCell(LIFCell): ## quadratic (leaky) LIF cell; inherits from LIFCell
198100 a single spike will be permitted to emit per step -- this means that
199101 if > 1 spikes emitted, a single action potential will be randomly
200102 sampled from the non-zero spikes detected
201- """
202-
203- # Define Functions
204- def __init__ (self , name , n_units , tau_m , resist_m = 1. , thr = - 52. , v_rest = - 65. ,
205- v_reset = 60. , v_scale = - 41.6 , critical_V = 1. , tau_theta = 1e7 ,
206- theta_plus = 0.05 , refract_time = 5. , thr_jitter = 0. , one_spike = False ,
207- integration_type = "euler" , ** kwargs ):
208- super ().__init__ (name , n_units , tau_m , resist_m , thr , v_rest , v_reset ,
209- 1. , tau_theta , theta_plus , refract_time , thr_jitter ,
210- one_spike , integration_type , ** kwargs )
103+ """ ## batch_size arg?
104+
105+ @deprecate_args (thr_jitter = None )
106+ def __init__ (
107+ self , name , n_units , tau_m , resist_m = 1. , thr = - 52. , v_rest = - 65. , v_reset = - 60. , v_scale = - 41.6 , critical_v = 1. ,
108+ tau_theta = 1e7 , theta_plus = 0.05 , refract_time = 5. , one_spike = False , integration_type = "euler" ,
109+ surrgoate_type = "straight_through" , lower_clamp_voltage = True , ** kwargs
110+ ):
111+ super ().__init__ (
112+ name , n_units , tau_m , resist_m , thr , v_rest , v_reset , 1. , tau_theta , theta_plus , refract_time ,
113+ one_spike , integration_type , surrgoate_type , lower_clamp_voltage , ** kwargs
114+ )
211115 ## only two distinct additional constants distinguish the Quad-LIF cell
212116 self .v_c = v_scale
213- self .a0 = critical_V
117+ self .a0 = critical_v
214118
119+ @transition (output_compartments = ["v" , "s" , "s_raw" , "rfr" , "thr_theta" , "tols" , "key" , "surrogate" ])
215120 @staticmethod
216- def _advance_state (t , dt , tau_m , R_m , v_rest , v_reset , refract_T , thr ,
217- tau_theta , theta_plus , one_spike , v_c , a0 , intgFlag , key ,
218- j , v , s , rfr , thr_theta , tols ):
219- ## Note: this runs quadratic LIF neuronal dynamics but constrained to be
220- ## similar to the general form of LIF dynamics
121+ def advance_state (
122+ t , dt , tau_m , resist_m , v_rest , v_reset , v_c , a0 , refract_T , thr , tau_theta , theta_plus ,
123+ one_spike , lower_clamp_voltage , intgFlag , d_spike_fx , key , j , v , rfr , thr_theta , tols
124+ ):
221125 skey = None ## this is an empty dkey if single_spike mode turned off
222- if one_spike : ## old code ~> if self.one_spike is False:
223- key , * subkeys = random .split (key , 2 )
224- skey = subkeys [0 ]
126+ if one_spike :
127+ key , skey = random .split (key , 2 )
225128 ## run one integration step for neuronal dynamics
226- j = j * R_m
227- v , s , raw_spikes , rfr = _run_cell (dt , j , v , thr , thr_theta , rfr , skey ,
228- v_c , a0 , tau_m , v_rest , v_reset ,
229- refract_T , intgFlag )
129+ j = j * resist_m
130+ ############################################################################
131+ ### Runs leaky integrator (leaky integrate-and-fire; LIF) neuronal dynamics.
132+ _v_thr = thr_theta + thr #v_theta + v_thr ## calc present voltage threshold
133+ #mask = (rfr >= refract_T).astype(jnp.float32) # get refractory mask
134+ ## update voltage / membrane potential
135+ v_params = (j , rfr , tau_m , refract_T , v_rest , v_c , a0 )
136+ if intgFlag == 1 :
137+ _ , _v = step_rk2 (0. , v , _dfv , dt , v_params )
138+ else : #_v = v + (v_rest - v) * (dt/tau_m) + (j * mask)
139+ _ , _v = step_euler (0. , v , _dfv , dt , v_params )
140+ ## obtain action potentials/spikes
141+ s = (_v > _v_thr ) * 1.
142+ ## update refractory variables
143+ _rfr = (rfr + dt ) * (1. - s )
144+ ## perform hyper-polarization of neuronal cells
145+ _v = _v * (1. - s ) + s * v_reset
146+
147+ raw_s = s + 0 ## preserve un-altered spikes
148+ ############################################################################
149+ ## this is a spike post-processing step
150+ if skey is not None :
151+ m_switch = (jnp .sum (s ) > 0. ).astype (jnp .float32 ) ## TODO: not batch-able
152+ rS = s * random .uniform (skey , s .shape )
153+ rS = nn .one_hot (jnp .argmax (rS , axis = 1 ), num_classes = s .shape [1 ],
154+ dtype = jnp .float32 )
155+ s = s * (1. - m_switch ) + rS * m_switch
156+ ############################################################################
157+ raw_spikes = raw_s
158+ v = _v
159+ rfr = _rfr
160+
161+ surrogate = d_spike_fx (v , _v_thr ) #d_spike_fx(v, thr + thr_theta)
230162 if tau_theta > 0. :
231163 ## run one integration step for threshold dynamics
232- thr_theta = _update_theta (dt , thr_theta , raw_spikes , tau_theta ,
233- theta_plus )
164+ thr_theta = _update_theta (dt , thr_theta , raw_spikes , tau_theta , theta_plus )
234165 ## update tols
235- tols = _update_times (t , s , tols )
236- return v , s , raw_spikes , rfr , thr_theta , tols , key
237-
238- @resolver (_advance_state )
239- def advance_state (self , v , s , s_raw , rfr , thr_theta , tols , key ):
240- self .v .set (v )
241- self .s .set (s )
242- self .s_raw .set (s_raw )
243- self .rfr .set (rfr )
244- self .thr_theta .set (thr_theta )
245- self .tols .set (tols )
246- self .key .set (key )
166+ tols = (1. - s ) * tols + (s * t )
167+ if lower_clamp_voltage : ## ensure voltage never < v_rest
168+ v = jnp .maximum (v , v_rest )
169+ return v , s , raw_spikes , rfr , thr_theta , tols , key , surrogate
170+
171+ @transition (output_compartments = ["j" , "v" , "s" , "s_raw" , "rfr" , "tols" , "surrogate" ])
172+ @staticmethod
173+ def reset (batch_size , n_units , v_rest , refract_T ):
174+ restVals = jnp .zeros ((batch_size , n_units ))
175+ j = restVals #+ 0
176+ v = restVals + v_rest
177+ s = restVals #+ 0
178+ s_raw = restVals
179+ rfr = restVals + refract_T
180+ #thr_theta = restVals ## do not reset thr_theta
181+ tols = restVals #+ 0
182+ surrogate = restVals + 1.
183+ return j , v , s , s_raw , rfr , tols , surrogate
184+
185+ def save (self , directory , ** kwargs ):
186+ ## do a protected save of constants, depending on whether they are floats or arrays
187+ tau_m = (self .tau_m if isinstance (self .tau_m , float )
188+ else jnp .asarray ([[self .tau_m * 1. ]]))
189+ thr = (self .thr if isinstance (self .thr , float )
190+ else jnp .asarray ([[self .thr * 1. ]]))
191+ v_rest = (self .v_rest if isinstance (self .v_rest , float )
192+ else jnp .asarray ([[self .v_rest * 1. ]]))
193+ v_reset = (self .v_reset if isinstance (self .v_reset , float )
194+ else jnp .asarray ([[self .v_reset * 1. ]]))
195+ v_decay = (self .v_decay if isinstance (self .v_decay , float )
196+ else jnp .asarray ([[self .v_decay * 1. ]]))
197+ resist_m = (self .resist_m if isinstance (self .resist_m , float )
198+ else jnp .asarray ([[self .resist_m * 1. ]]))
199+ tau_theta = (self .tau_theta if isinstance (self .tau_theta , float )
200+ else jnp .asarray ([[self .tau_theta * 1. ]]))
201+ theta_plus = (self .theta_plus if isinstance (self .theta_plus , float )
202+ else jnp .asarray ([[self .theta_plus * 1. ]]))
203+
204+ file_name = directory + "/" + self .name + ".npz"
205+ jnp .savez (file_name ,
206+ threshold_theta = self .thr_theta .value ,
207+ tau_m = tau_m , thr = thr , v_rest = v_rest ,
208+ v_reset = v_reset , v_decay = v_decay ,
209+ resist_m = resist_m , tau_theta = tau_theta ,
210+ theta_plus = theta_plus ,
211+ key = self .key .value )
212+
213+ def load (self , directory , seeded = False , ** kwargs ):
214+ file_name = directory + "/" + self .name + ".npz"
215+ data = jnp .load (file_name )
216+ self .thr_theta .set (data ['threshold_theta' ])
217+ ## constants loaded in
218+ self .tau_m = data ['tau_m' ]
219+ self .thr = data ['thr' ]
220+ self .v_rest = data ['v_rest' ]
221+ self .v_reset = data ['v_reset' ]
222+ self .v_decay = data ['v_decay' ]
223+ self .resist_m = data ['resist_m' ]
224+ self .tau_theta = data ['tau_theta' ]
225+ self .theta_plus = data ['theta_plus' ]
226+
227+ if seeded :
228+ self .key .set (data ['key' ])
247229
248230 @classmethod
249231 def help (cls ): ## component help function
250232 properties = {
251- "cell_type" : "QuadLIFCell - evolves neurons according to quadratic "
252- "leaky integrate- and-fire spiking dynamics."
233+ "cell_type" : "LIFCell - evolves neurons according to leaky integrate- "
234+ "and-fire spiking dynamics."
253235 }
254236 compartment_props = {
255237 "inputs" :
@@ -258,6 +240,7 @@ def help(cls): ## component help function
258240 {"v" : "Membrane potential/voltage at time t" ,
259241 "rfr" : "Current state of (relative) refractory variable" ,
260242 "thr" : "Current state of voltage threshold at time t" ,
243+ "thr_theta" : "Current state of homeostatic adaptive threshold at time t" ,
261244 "key" : "JAX PRNG key" },
262245 "outputs" :
263246 {"s" : "Emitted spikes/pulses at time t" ,
@@ -271,14 +254,18 @@ def help(cls): ## component help function
271254 "v_rest" : "Resting membrane potential value" ,
272255 "v_reset" : "Reset membrane potential value" ,
273256 "v_decay" : "Voltage leak/decay factor" ,
274- "v_scale" : "Scaling factor for voltage accumulation" ,
275- "critical_V" : "Critical voltage value" ,
276257 "tau_theta" : "Threshold/homoestatic increment time constant" ,
277- "theta_plus" : "Amount to increment threshold by upon occurrence of spike" ,
258+ "theta_plus" : "Amount to increment threshold by upon occurrence "
259+ "of spike" ,
278260 "refract_time" : "Length of relative refractory period (ms)" ,
279- "thr_jitter" : "Scale of random uniform noise to apply to initial condition of threshold" ,
280- "one_spike" : "Should only one spike be sampled/allowed to emit at any given time step?" ,
281- "integration_type" : "Type of numerical integration to use for the cell dynamics"
261+ "one_spike" : "Should only one spike be sampled/allowed to emit at "
262+ "any given time step?" ,
263+ "integration_type" : "Type of numerical integration to use for the "
264+ "cell dynamics" ,
265+ "surrgoate_type" : "Type of surrogate function to use approximate "
266+ "derivative of spike w.r.t. voltage/current" ,
267+ "lower_bound_clamp" : "Should voltage be lower bounded to be never "
268+ "be below `v_rest`"
282269 }
283270 info = {cls .__name__ : properties ,
284271 "compartments" : compartment_props ,
@@ -301,8 +288,7 @@ def __repr__(self):
301288 return lines
302289
303290if __name__ == '__main__' :
304- # NOTE: VN: currently error in init function
305291 from ngcsimlib .context import Context
306292 with Context ("Bar" ) as bar :
307- X = QuadLIFCell ("X" , 1 , 10. )
293+ X = QuadLIFCell ("X" , 9 , 0.0004 , 3 )
308294 print (X )
0 commit comments