1- from jax import numpy as jnp , jit
2- from ngclearn import resolver , Component , Compartment
31from ngclearn .components .jaxComponent import JaxComponent
2+ from jax import numpy as jnp , random , jit , nn
3+ from functools import partial
44from ngclearn .utils import tensorstats
5+ from ngcsimlib .deprecators import deprecate_args
6+ from ngcsimlib .logger import info , warn
57from ngclearn .utils .diffeq .ode_utils import get_integrator_code , \
68 step_euler , step_rk2
79
8- @jit
9- def _update_times (t , s , tols ):
10- """
11- Updates time-of-last-spike (tols) variable.
12-
13- Args:
14- t: current time (a scalar/int value)
15-
16- s: binary spike vector
17-
18- tols: current time-of-last-spike variable
10+ from ngcsimlib .compilers .process import transition
11+ #from ngcsimlib.component import Component
12+ from ngcsimlib .compartment import Compartment
1913
20- Returns:
21- updated tols variable
22- """
23- _tols = (1. - s ) * tols + (s * t )
24- return _tols
2514
2615@jit
2716def _dfv_internal (j , v , w , a , b , g , tau_m ): ## raw voltage dynamics
@@ -45,25 +34,6 @@ def _dfw(t, w, params): ## recovery dynamics wrapper
4534 dv_dt = _dfw_internal (j , v , w , a , b , g , tau_m )
4635 return dv_dt
4736
48- @jit
49- def _emit_spike (v , v_thr ):
50- s = (v > v_thr ).astype (jnp .float32 )
51- return s
52-
53- def _run_cell (dt , j , v , w , v_thr , tau_m , tau_w , a , b , g = 3. , integType = 0 ):
54- if integType == 1 :
55- v_params = (j , w , a , b , g , tau_m )
56- _ , _v = step_rk2 (0. , v , _dfv , dt , v_params ) #_v = step_rk2(v, v_params, _dfv, dt)
57- w_params = (j , v , a , b , g , tau_w )
58- _ , _w = step_rk2 (0. , w , _dfw , dt , w_params ) #_w = step_rk2(w, w_params, _dfw, dt)
59- else : # integType == 0 (default -- Euler)
60- v_params = (j , w , a , b , g , tau_m )
61- _ , _v = step_euler (0. , v , _dfv , dt , v_params ) #_v = step_euler(v, v_params, _dfv, dt)
62- w_params = (j , v , a , b , g , tau_w )
63- _ , _w = step_euler (0. , w , _dfw , dt , w_params ) #_w = step_euler(w, w_params, _dfw, dt)
64- s = _emit_spike (_v , v_thr )
65- return _v , _w , s
66-
6737class FitzhughNagumoCell (JaxComponent ):
6838 """
6939 The Fitzhugh-Nagumo neuronal cell model; a two-variable simplification
@@ -168,27 +138,34 @@ def __init__(self, name, n_units, tau_m=1., resist_m=1., tau_w=12.5, alpha=0.7,
168138 self .s = Compartment (restVals )
169139 self .tols = Compartment (restVals ) ## time-of-last-spike
170140
141+ @transition (output_compartments = ["j" , "v" , "w" , "s" , "tols" ])
171142 @staticmethod
172- def _advance_state (t , dt , tau_m , R_m , tau_w , v_thr , spike_reset , v0 , w0 , alpha ,
143+ def advance_state (t , dt , tau_m , R_m , tau_w , v_thr , spike_reset , v0 , w0 , alpha ,
173144 beta , gamma , intgFlag , j , v , w , tols ):
174- v , w , s = _run_cell (dt , j * R_m , v , w , v_thr , tau_m , tau_w , alpha , beta ,
175- gamma , intgFlag )
145+ j_mod = j * R_m
146+ if intgFlag == 1 :
147+ v_params = (j_mod , w , alpha , beta , gamma , tau_m )
148+ _ , _v = step_rk2 (0. , v , _dfv , dt , v_params ) # _v = step_rk2(v, v_params, _dfv, dt)
149+ w_params = (j_mod , v , alpha , beta , gamma , tau_w )
150+ _ , _w = step_rk2 (0. , w , _dfw , dt , w_params ) # _w = step_rk2(w, w_params, _dfw, dt)
151+ else : # integType == 0 (default -- Euler)
152+ v_params = (j_mod , w , alpha , beta , gamma , tau_m )
153+ _ , _v = step_euler (0. , v , _dfv , dt , v_params ) # _v = step_euler(v, v_params, _dfv, dt)
154+ w_params = (j_mod , v , alpha , beta , gamma , tau_w )
155+ _ , _w = step_euler (0. , w , _dfw , dt , w_params ) # _w = step_euler(w, w_params, _dfw, dt)
156+ s = (_v > v_thr ) * 1.
157+ v = _v
158+ w = _w
159+
176160 if spike_reset : ## if spike-reset used, variables snapped back to initial conditions
177161 v = v * (1. - s ) + s * v0
178162 w = w * (1. - s ) + s * w0
179- tols = _update_times ( t , s , tols )
163+ tols = ( 1. - s ) * tols + ( s * t ) ## update tols
180164 return j , v , w , s , tols
181165
182- @resolver (_advance_state )
183- def advance_state (self , j , v , w , s , tols ):
184- self .j .set (j )
185- self .w .set (w )
186- self .v .set (v )
187- self .s .set (s )
188- self .tols .set (tols )
189-
166+ @transition (output_compartments = ["j" , "v" , "w" , "s" , "tols" ])
190167 @staticmethod
191- def _reset (batch_size , n_units , v0 , w0 ):
168+ def reset (batch_size , n_units , v0 , w0 ):
192169 restVals = jnp .zeros ((batch_size , n_units ))
193170 j = restVals # None
194171 v = restVals + v0
@@ -197,14 +174,6 @@ def _reset(batch_size, n_units, v0, w0):
197174 tols = restVals #+ 0
198175 return j , v , w , s , tols
199176
200- @resolver (_reset )
201- def reset (self , j , v , w , s , tols ):
202- self .j .set (j )
203- self .v .set (v )
204- self .w .set (w )
205- self .s .set (s )
206- self .tols .set (tols )
207-
208177 @classmethod
209178 def help (cls ): ## component help function
210179 properties = {
0 commit comments