1- from jax import numpy as jnp , jit
2- from ngclearn import resolver , Component , Compartment
31from ngclearn .components .jaxComponent import JaxComponent
2+ from jax import numpy as jnp , random , jit , nn
3+ from functools import partial
44from ngclearn .utils import tensorstats
55from ngcsimlib .deprecators import deprecate_args
6+ from ngcsimlib .logger import info , warn
67from ngclearn .utils .diffeq .ode_utils import get_integrator_code , \
78 step_euler , step_rk2
8-
9- @jit
10- def _update_times (t , s , tols ):
11- """
12- Updates time-of-last-spike (tols) variable.
13-
14- Args:
15- t: current time (a scalar/int value)
16-
17- s: binary spike vector
18-
19- tols: current time-of-last-spike variable
20-
21- Returns:
22- updated tols variable
23- """
24- _tols = (1. - s ) * tols + (s * t )
25- return _tols
9+ from ngcsimlib .compilers .process import transition
10+ #from ngcsimlib.component import Component
11+ from ngcsimlib .compartment import Compartment
2612
2713@jit
2814def _dfv_internal (j , v , w , tau_m , v_rest , sharpV , vT , R_m ): ## raw voltage dynamics
@@ -46,30 +32,6 @@ def _dfw(t, w, params): ## recovery dynamics wrapper
4632 dv_dt = _dfw_internal (j , v , w , a , tau_m , v_rest )
4733 return dv_dt
4834
49- @jit
50- def _emit_spike (v , v_thr ):
51- s = (v > v_thr ).astype (jnp .float32 )
52- return s
53-
54- #@partial(jit, static_argnums=[10])
55- def _run_cell (dt , j , v , w , v_thr , tau_m , tau_w , a , b , sharpV , vT ,
56- v_rest , v_reset , R_m , integType = 0 ):
57- if integType == 1 : ## RK-2/midpoint
58- v_params = (j , w , tau_m , v_rest , sharpV , vT , R_m )
59- _ , _v = step_rk2 (0. , v , _dfv , dt , v_params )
60- w_params = (j , v , a , tau_w , v_rest )
61- _ , _w = step_rk2 (0. , w , _dfw , dt , w_params )
62- else : # integType == 0 (default -- Euler)
63- v_params = (j , w , tau_m , v_rest , sharpV , vT , R_m )
64- _ , _v = step_euler (0. , v , _dfv , dt , v_params )
65- w_params = (j , v , a , tau_w , v_rest )
66- _ , _w = step_euler (0. , w , _dfw , dt , w_params )
67- s = _emit_spike (_v , v_thr )
68- ## hyperpolarize/reset/snap variables
69- _v = _v * (1. - s ) + s * v_reset
70- _w = _w * (1. - s ) + s * (_w + b )
71- return _v , _w , s
72-
7335class AdExCell (JaxComponent ):
7436 """
7537 The AdEx (adaptive exponential leaky integrate-and-fire) neuronal cell
@@ -136,10 +98,10 @@ class AdExCell(JaxComponent):
13698 """
13799
138100 @deprecate_args (v_thr = "thr" )
139- def __init__ (self , name , n_units , tau_m = 15. , resist_m = 1. , tau_w = 400. ,
140- v_sharpness = 2. , intrinsic_mem_thr = - 55. , thr = 5. , v_rest = - 72 . ,
141- v_reset = - 75. , a = 0.1 , b = 0.75 , v0 = - 70. , w0 = 0. ,
142- integration_type = "euler" , batch_size = 1 , ** kwargs ):
101+ def __init__ (
102+ self , name , n_units , tau_m = 15. , resist_m = 1. , tau_w = 400. , v_sharpness = 2. , intrinsic_mem_thr = - 55. , thr = 5. ,
103+ v_rest = - 72. , v_reset = - 75. , a = 0.1 , b = 0.75 , v0 = - 70. , w0 = 0. , integration_type = "euler" , batch_size = 1 , ** kwargs
104+ ):
143105 super ().__init__ (name , ** kwargs )
144106
145107 ## Integration properties
@@ -174,24 +136,32 @@ def __init__(self, name, n_units, tau_m=15., resist_m=1., tau_w=400.,
174136 self .tols = Compartment (restVals , display_name = "Time-of-Last-Spike" ,
175137 units = "ms" ) ## time-of-last-spike
176138
139+ @transition (output_compartments = ["j" , "v" , "w" , "s" , "tols" ])
177140 @staticmethod
178- def _advance_state (t , dt , tau_m , R_m , tau_w , thr , a , b , sharpV , vT ,
179- v_rest , v_reset , intgFlag , j , v , w , tols ):
180- v , w , s = _run_cell (dt , j , v , w , thr , tau_m , tau_w , a , b , sharpV , vT ,
181- v_rest , v_reset , R_m , intgFlag )
182- tols = _update_times (t , s , tols )
141+ def advance_state (
142+ t , dt , tau_m , R_m , tau_w , thr , a , b , sharpV , vT , v_rest , v_reset , intgFlag , j , v , w , tols
143+ ):
144+ if intgFlag == 1 : ## RK-2/midpoint
145+ v_params = (j , w , tau_m , v_rest , sharpV , vT , R_m )
146+ _ , _v = step_rk2 (0. , v , _dfv , dt , v_params )
147+ w_params = (j , v , a , tau_w , v_rest )
148+ _ , _w = step_rk2 (0. , w , _dfw , dt , w_params )
149+ else : # intgFlag == 0 (default -- Euler)
150+ v_params = (j , w , tau_m , v_rest , sharpV , vT , R_m )
151+ _ , _v = step_euler (0. , v , _dfv , dt , v_params )
152+ w_params = (j , v , a , tau_w , v_rest )
153+ _ , _w = step_euler (0. , w , _dfw , dt , w_params )
154+ s = (_v > thr ) * 1. ## emit spikes/pulses
155+ ## hyperpolarize/reset/snap variables
156+ v = _v * (1. - s ) + s * v_reset
157+ w = _w * (1. - s ) + s * (_w + b )
158+
159+ tols = (1. - s ) * tols + (s * t ) ## update time-of-last spike variable(s)
183160 return j , v , w , s , tols
184161
185- @resolver (_advance_state )
186- def advance_state (self , j , v , w , s , tols ):
187- self .j .set (j )
188- self .w .set (w )
189- self .v .set (v )
190- self .s .set (s )
191- self .tols .set (tols )
192-
162+ @transition (output_compartments = ["j" , "v" , "w" , "s" , "tols" ])
193163 @staticmethod
194- def _reset (batch_size , n_units , v0 , w0 ):
164+ def reset (batch_size , n_units , v0 , w0 ):
195165 restVals = jnp .zeros ((batch_size , n_units ))
196166 j = restVals # None
197167 v = restVals + v0
@@ -200,14 +170,6 @@ def _reset(batch_size, n_units, v0, w0):
200170 tols = restVals #+ 0
201171 return j , v , w , s , tols
202172
203- @resolver (_reset )
204- def reset (self , j , v , w , s , tols ):
205- self .j .set (j )
206- self .v .set (v )
207- self .w .set (w )
208- self .s .set (s )
209- self .tols .set (tols )
210-
211173 @classmethod
212174 def help (cls ): ## component help function
213175 properties = {
0 commit comments