1- from jax import numpy as jnp , jit
2- from ngclearn .utils import tensorstats
3- from ngclearn import resolver , Component , Compartment
41from ngclearn .components .jaxComponent import JaxComponent
2+ from jax import numpy as jnp , random , jit , nn
3+ from functools import partial
4+ from ngclearn .utils import tensorstats
5+ from ngcsimlib .deprecators import deprecate_args
6+ from ngcsimlib .logger import info , warn
57from ngclearn .utils .diffeq .ode_utils import get_integrator_code , \
68 step_euler , step_rk2
79
8- @jit
9- def _update_times (t , s , tols ):
10- """
11- Updates time-of-last-spike (tols) variable.
12-
13- Args:
14- t: current time (a scalar/int value)
15-
16- s: binary spike vector
10+ from ngcsimlib .compilers .process import transition
11+ #from ngcsimlib.component import Component
12+ from ngcsimlib .compartment import Compartment
1713
18- tols: current time-of-last-spike variable
19-
20- Returns:
21- updated tols variable
22- """
23- _tols = (1. - s ) * tols + (s * t )
24- return _tols
2514
2615@jit
2716def _dfv_internal (j , v , w , b , tau_m ): ## raw voltage dynamics
@@ -55,39 +44,6 @@ def _post_process(s, _v, _w, v, w, c, d): ## internal post-processing routine
5544 w_next = _w * (1. - s ) + s * (w + d )
5645 return v_next , w_next
5746
58- @jit
59- def _emit_spike (v , v_thr ):
60- s = (v > v_thr ).astype (jnp .float32 )
61- return s
62-
63- @jit
64- def _modify_current (j , R_m ):
65- _j = j * R_m
66- return _j
67-
68- def _run_cell (dt , j , v , s , w , v_thr = 30. , tau_m = 1. , tau_w = 50. , b = 0.2 , c = - 65. , d = 8. ,
69- R_m = 1. , integType = 0 ):
70- ## note: a = 0.1 --> fast spikes, a = 0.02 --> regular spikes
71- a = 1. / tau_w ## we map time constant to variable "a" (a = 1/tau_w)
72- _j = _modify_current (j , R_m )
73- #_j = jnp.maximum(-30.0, _j) ## lower-bound/clip input current
74- ## check for spikes
75- s = _emit_spike (v , v_thr )
76- ## for non-spikes, evolve according to dynamics
77- if integType == 1 :
78- v_params = (_j , w , b , tau_m )
79- _ , _v = step_rk2 (0. , v , _dfv , dt , v_params ) #_v = step_rk2(v, v_params, _dfv, dt)
80- w_params = (_j , v , b , tau_w )
81- _ , _w = step_rk2 (0. , w , _dfw , dt , w_params ) #_w = step_rk2(w, w_params, _dfw, dt)
82- else : # integType == 0 (default -- Euler)
83- v_params = (_j , w , b , tau_m )
84- _ , _v = step_euler (0. , v , _dfv , dt , v_params ) #_v = step_euler(v, v_params, _dfv, dt)
85- w_params = (_j , v , b , tau_w )
86- _ , _w = step_euler (0. , w , _dfw , dt , w_params ) #_w = step_euler(w, w_params, _dfw, dt)
87- ## for spikes, snap to particular states
88- _v , _w = _post_process (s , _v , _w , v , w , c , d )
89- return _v , _w , s
90-
9147class IzhikevichCell (JaxComponent ): ## Izhikevich neuronal cell
9248 """
9349 A spiking cell based on Izhikevich's model of neuronal dynamics. Note that
@@ -197,24 +153,38 @@ def __init__(self, name, n_units, tau_m=1., resist_m=1., v_thr=30., v_reset=-65.
197153 self .s = Compartment (restVals )
198154 self .tols = Compartment (restVals ) ## time-of-last-spike
199155
156+ @transition (output_compartments = ["j" , "v" , "w" , "s" , "tols" ])
200157 @staticmethod
201- def _advance_state (t , dt , tau_m , tau_w , v_thr , coupling , v_reset , w_reset , R_m ,
158+ def advance_state (t , dt , tau_m , tau_w , v_thr , coupling , v_reset , w_reset , R_m ,
202159 intgFlag , j , v , w , s , tols ):
203- v , w , s = _run_cell (dt , j , v , s , w , v_thr = v_thr , tau_m = tau_m , tau_w = tau_w ,
204- b = coupling , c = v_reset , d = w_reset , R_m = R_m , integType = intgFlag )
205- tols = _update_times (t , s , tols )
160+ ## note: a = 0.1 --> fast spikes, a = 0.02 --> regular spikes
161+ a = 1. / tau_w ## we map time constant to variable "a" (a = 1/tau_w)
162+ _j = j * R_m
163+ # _j = jnp.maximum(-30.0, _j) ## lower-bound/clip input current
164+ ## check for spikes
165+ s = (v > v_thr ) * 1.
166+ ## for non-spikes, evolve according to dynamics
167+ if intgFlag == 1 :
168+ v_params = (_j , w , coupling , tau_m )
169+ _ , _v = step_rk2 (0. , v , _dfv , dt , v_params ) # _v = step_rk2(v, v_params, _dfv, dt)
170+ w_params = (_j , v , coupling , tau_w )
171+ _ , _w = step_rk2 (0. , w , _dfw , dt , w_params ) # _w = step_rk2(w, w_params, _dfw, dt)
172+ else : # integType == 0 (default -- Euler)
173+ v_params = (_j , w , coupling , tau_m )
174+ _ , _v = step_euler (0. , v , _dfv , dt , v_params ) # _v = step_euler(v, v_params, _dfv, dt)
175+ w_params = (_j , v , coupling , tau_w )
176+ _ , _w = step_euler (0. , w , _dfw , dt , w_params ) # _w = step_euler(w, w_params, _dfw, dt)
177+ ## for spikes, snap to particular states
178+ _v , _w = _post_process (s , _v , _w , v , w , v_reset , w_reset )
179+ v = _v
180+ w = _w
181+
182+ tols = (1. - s ) * tols + (s * t ) ## update tols
206183 return j , v , w , s , tols
207184
208- @resolver (_advance_state )
209- def advance_state (self , j , v , w , s , tols ):
210- self .j .set (j )
211- self .w .set (w )
212- self .v .set (v )
213- self .s .set (s )
214- self .tols .set (tols )
215-
185+ @transition (output_compartments = ["j" , "v" , "w" , "s" , "tols" ])
216186 @staticmethod
217- def _reset (batch_size , n_units , v0 , w0 ):
187+ def reset (batch_size , n_units , v0 , w0 ):
218188 restVals = jnp .zeros ((batch_size , n_units ))
219189 j = restVals # None
220190 v = restVals + v0
@@ -223,14 +193,6 @@ def _reset(batch_size, n_units, v0, w0):
223193 tols = restVals #+ 0
224194 return j , v , w , s , tols
225195
226- @resolver (_reset )
227- def reset (self , j , v , w , s , tols ):
228- self .j .set (j )
229- self .v .set (v )
230- self .w .set (w )
231- self .s .set (s )
232- self .tols .set (tols )
233-
234196 @classmethod
235197 def help (cls ): ## component help function
236198 properties = {
0 commit comments