1+ """
12from jax import numpy as jnp, random, jit, nn
23from functools import partial
34from ngclearn.utils import tensorstats
45from ngcsimlib.deprecators import deprecate_args
56from ngclearn import resolver, Component, Compartment
67from ngclearn.components.jaxComponent import JaxComponent
8+ from ngclearn.utils.diffeq.ode_utils import get_integrator_code, \
9+ step_euler, step_rk2
10+ from ngclearn.utils.surrogate_fx import (secant_lif_estimator, arctan_estimator,
11+ triangular_estimator,
12+ straight_through_estimator)
13+ """
14+ from ngclearn .components .jaxComponent import JaxComponent
15+ from jax import numpy as jnp , random , jit , nn
16+ from functools import partial
17+ from ngclearn .utils import tensorstats
18+ from ngcsimlib .deprecators import deprecate_args
19+ from ngcsimlib .logger import info , warn
720from ngclearn .utils .diffeq .ode_utils import get_integrator_code , \
821 step_euler , step_rk2
922from ngclearn .utils .surrogate_fx import (secant_lif_estimator , arctan_estimator ,
1023 triangular_estimator ,
1124 straight_through_estimator )
1225
13- @jit
14- def _update_times (t , s , tols ):
15- """
16- Updates time-of-last-spike (tols) variable.
17-
18- Args:
19- t: current time (a scalar/int value)
20-
21- s: binary spike vector
22-
23- tols: current time-of-last-spike variable
24-
25- Returns:
26- updated tols variable
27- """
28- _tols = (1. - s ) * tols + (s * t )
29- return _tols
26+ from ngcsimlib .compilers .process import transition
27+ #from ngcsimlib.component import Component
28+ from ngcsimlib .compartment import Compartment
3029
3130@jit
3231def _dfv_internal (j , v , rfr , tau_m , refract_T , v_rest , v_decay = 1. ): ## raw voltage dynamics
@@ -41,37 +40,6 @@ def _dfv(t, v, params): ## voltage dynamics wrapper
4140 dv_dt = _dfv_internal (j , v , rfr , tau_m , refract_T , v_rest , v_decay )
4241 return dv_dt
4342
44- #@partial(jit, static_argnums=[7, 8, 9, 10, 11, 12])
45- def _run_cell (dt , j , v , v_thr , v_theta , rfr , skey , tau_m , v_rest , v_reset ,
46- v_decay , refract_T , integType = 0 ):
47- ### Runs leaky integrator (or leaky integrate-and-fire; LIF) neuronal dynamics.
48- _v_thr = v_theta + v_thr ## calc present voltage threshold
49- #mask = (rfr >= refract_T).astype(jnp.float32) # get refractory mask
50- ## update voltage / membrane potential
51- v_params = (j , rfr , tau_m , refract_T , v_rest , v_decay )
52- if integType == 1 :
53- _ , _v = step_rk2 (0. , v , _dfv , dt , v_params )
54- else : #_v = v + (v_rest - v) * (dt/tau_m) + (j * mask)
55- _ , _v = step_euler (0. , v , _dfv , dt , v_params )
56- ## obtain action potentials/spikes
57- s = (_v > _v_thr ).astype (jnp .float32 )
58- ## update refractory variables
59- _rfr = (rfr + dt ) * (1. - s )
60- ## perform hyper-polarization of neuronal cells
61- _v = _v * (1. - s ) + s * v_reset
62-
63- raw_s = s + 0 ## preserve un-altered spikes
64- ############################################################################
65- ## this is a spike post-processing step
66- if skey is not None :
67- m_switch = (jnp .sum (s ) > 0. ).astype (jnp .float32 ) ## TODO: not batch-able
68- rS = s * random .uniform (skey , s .shape )
69- rS = nn .one_hot (jnp .argmax (rS , axis = 1 ), num_classes = s .shape [1 ],
70- dtype = jnp .float32 )
71- s = s * (1. - m_switch ) + rS * m_switch
72- ############################################################################
73- return _v , s , raw_s , _rfr
74-
7543#@partial(jit, static_argnums=[3, 4])
7644def _update_theta (dt , v_theta , s , tau_theta , theta_plus = 0.05 ):
7745 ### Runs homeostatic threshold update dynamics one step (via Euler integration).
@@ -159,7 +127,7 @@ class LIFCell(JaxComponent): ## leaky integrate-and-fire cell
159127
160128 lower_clamp_voltage: if True, this will ensure voltage never is below
161129 the value of `v_rest` (default: True)
162- """
130+ """ ## batch_size arg?
163131
164132 @deprecate_args (thr_jitter = None )
165133 def __init__ (self , name , n_units , tau_m , resist_m = 1. , thr = - 52. , v_rest = - 65. ,
@@ -220,41 +188,61 @@ def __init__(self, name, n_units, tau_m, resist_m=1., thr=-52., v_rest=-65.,
220188 units = "ms" ) ## time-of-last-spike
221189 self .surrogate = Compartment (restVals + 1. , display_name = "Surrogate State Value" )
222190
191+ @transition (output_compartments = ["v" , "s" , "s_raw" , "rfr" , "thr_theta" , "tols" , "key" , "surrogate" ])
223192 @staticmethod
224- def _advance_state (t , dt , tau_m , resist_m , v_rest , v_reset , v_decay , refract_T ,
225- thr , tau_theta , theta_plus , one_spike , lower_clamp_voltage ,
226- intgFlag , d_spike_fx , key , j , v , rfr , thr_theta , tols ):
193+ def advance_state (
194+ t , dt , tau_m , resist_m , v_rest , v_reset , v_decay , refract_T , thr , tau_theta , theta_plus ,
195+ one_spike , lower_clamp_voltage , intgFlag , d_spike_fx , key , j , v , rfr , thr_theta , tols
196+ ):
227197 skey = None ## this is an empty dkey if single_spike mode turned off
228198 if one_spike :
229199 key , skey = random .split (key , 2 )
230200 ## run one integration step for neuronal dynamics
231201 j = j * resist_m
232- v , s , raw_spikes , rfr = _run_cell (dt , j , v , thr , thr_theta , rfr , skey ,
233- tau_m , v_rest , v_reset , v_decay ,
234- refract_T , intgFlag )
235- surrogate = d_spike_fx (v , thr + thr_theta )
202+ ############################################################################
203+ ### Runs leaky integrator (leaky integrate-and-fire; LIF) neuronal dynamics.
204+ _v_thr = thr_theta + thr #v_theta + v_thr ## calc present voltage threshold
205+ #mask = (rfr >= refract_T).astype(jnp.float32) # get refractory mask
206+ ## update voltage / membrane potential
207+ v_params = (j , rfr , tau_m , refract_T , v_rest , v_decay )
208+ if intgFlag == 1 :
209+ _ , _v = step_rk2 (0. , v , _dfv , dt , v_params )
210+ else : #_v = v + (v_rest - v) * (dt/tau_m) + (j * mask)
211+ _ , _v = step_euler (0. , v , _dfv , dt , v_params )
212+ ## obtain action potentials/spikes
213+ s = (_v > _v_thr ).astype (jnp .float32 )
214+ ## update refractory variables
215+ _rfr = (rfr + dt ) * (1. - s )
216+ ## perform hyper-polarization of neuronal cells
217+ _v = _v * (1. - s ) + s * v_reset
218+
219+ raw_s = s + 0 ## preserve un-altered spikes
220+ ############################################################################
221+ ## this is a spike post-processing step
222+ if skey is not None :
223+ m_switch = (jnp .sum (s ) > 0. ).astype (jnp .float32 ) ## TODO: not batch-able
224+ rS = s * random .uniform (skey , s .shape )
225+ rS = nn .one_hot (jnp .argmax (rS , axis = 1 ), num_classes = s .shape [1 ],
226+ dtype = jnp .float32 )
227+ s = s * (1. - m_switch ) + rS * m_switch
228+ ############################################################################
229+ raw_spikes = raw_s
230+ v = _v
231+ rfr = _rfr
232+
233+ surrogate = d_spike_fx (v , _v_thr ) #d_spike_fx(v, thr + thr_theta)
236234 if tau_theta > 0. :
237235 ## run one integration step for threshold dynamics
238236 thr_theta = _update_theta (dt , thr_theta , raw_spikes , tau_theta , theta_plus )
239237 ## update tols
240- tols = _update_times ( t , s , tols )
238+ tols = ( 1. - s ) * tols + ( s * t )
241239 if lower_clamp_voltage : ## ensure voltage never < v_rest
242240 v = jnp .maximum (v , v_rest )
243241 return v , s , raw_spikes , rfr , thr_theta , tols , key , surrogate
244242
245- @resolver (_advance_state )
246- def advance_state (self , v , s , s_raw , rfr , thr_theta , tols , key , surrogate ):
247- self .v .set (v )
248- self .s .set (s )
249- self .s_raw .set (s_raw )
250- self .rfr .set (rfr )
251- self .thr_theta .set (thr_theta )
252- self .tols .set (tols )
253- self .key .set (key )
254- self .surrogate .set (surrogate )
255-
243+ @transition (output_compartments = ["j" , "v" , "s" , "s_raw" , "rfr" , "tols" , "surrogate" ])
256244 @staticmethod
257- def _reset (batch_size , n_units , v_rest , refract_T ):
245+ def reset (batch_size , n_units , v_rest , refract_T ):
258246 restVals = jnp .zeros ((batch_size , n_units ))
259247 j = restVals #+ 0
260248 v = restVals + v_rest
@@ -266,16 +254,6 @@ def _reset(batch_size, n_units, v_rest, refract_T):
266254 surrogate = restVals + 1.
267255 return j , v , s , s_raw , rfr , tols , surrogate
268256
269- @resolver (_reset )
270- def reset (self , j , v , s , s_raw , rfr , tols , surrogate ):
271- self .j .set (j )
272- self .v .set (v )
273- self .s .set (s )
274- self .s_raw .set (s_raw )
275- self .rfr .set (rfr )
276- self .tols .set (tols )
277- self .surrogate .set (surrogate )
278-
279257 def save (self , directory , ** kwargs ):
280258 ## do a protected save of constants, depending on whether they are floats or arrays
281259 tau_m = (self .tau_m if isinstance (self .tau_m , float )
0 commit comments