22from jax import numpy as jnp , random , jit , nn
33from functools import partial
44from ngclearn .utils import tensorstats
5- from ngcsimlib . deprecators import deprecate_args
5+ from ngcsimlib import deprecate_args
66from ngcsimlib .logger import info , warn
77from ngclearn .utils .diffeq .ode_utils import get_integrator_code , \
88 step_euler , step_rk2
99
10- from ngcsimlib .compilers .process import transition
11- #from ngcsimlib.component import Component
10+ from ngcsimlib .parser import compilable
1211from ngcsimlib .compartment import Compartment
1312
1413
@@ -34,7 +33,7 @@ def _dfw(t, w, params): ## recovery dynamics wrapper
3433 dv_dt = _dfw_internal (j , v , w , a , b , g , tau_m )
3534 return dv_dt
3635
37- class FitzhughNagumoCell (JaxComponent ):
36+ class FitzhughNagumoCell (JaxComponent ): ## F-H cell
3837 """
3938 The Fitzhugh-Nagumo neuronal cell model; a two-variable simplification
4039 of the Hodgkin-Huxley (squid axon) model. This cell model iteratively evolves
@@ -103,10 +102,10 @@ class FitzhughNagumoCell(JaxComponent):
103102 at an increase in computational cost (and simulation time)
104103 """
105104
106- # Define Functions
107- def __init__ ( self , name , n_units , tau_m = 1. , resist_m = 1. , tau_w = 12.5 , alpha = 0.7 ,
108- beta = 0.8 , gamma = 3. , v0 = 0. , w0 = 0. , v_thr = 1.07 , spike_reset = False ,
109- integration_type = "euler" , ** kwargs ):
105+ def __init__ (
106+ self , name , n_units , tau_m = 1. , resist_m = 1. , tau_w = 12.5 , alpha = 0.7 , beta = 0.8 , gamma = 3. , v0 = 0. , w0 = 0. ,
107+ v_thr = 1.07 , spike_reset = False , integration_type = "euler" , ** kwargs
108+ ):
110109 super ().__init__ (name , ** kwargs )
111110
112111 ## Integration properties
@@ -115,7 +114,7 @@ def __init__(self, name, n_units, tau_m=1., resist_m=1., tau_w=12.5, alpha=0.7,
115114
116115 ## Cell properties
117116 self .tau_m = tau_m
118- self .R_m = resist_m
117+ self .resist_m = resist_m ## resistance R_m
119118 self .tau_w = tau_w
120119 self .alpha = alpha
121120 self .beta = beta
@@ -138,41 +137,44 @@ def __init__(self, name, n_units, tau_m=1., resist_m=1., tau_w=12.5, alpha=0.7,
138137 self .s = Compartment (restVals )
139138 self .tols = Compartment (restVals ) ## time-of-last-spike
140139
141- @transition (output_compartments = ["j" , "v" , "w" , "s" , "tols" ])
142- @staticmethod
143- def advance_state (t , dt , tau_m , R_m , tau_w , v_thr , spike_reset , v0 , w0 , alpha ,
144- beta , gamma , intgFlag , j , v , w , tols ):
145- j_mod = j * R_m
146- if intgFlag == 1 :
147- v_params = (j_mod , w , alpha , beta , gamma , tau_m )
148- _ , _v = step_rk2 (0. , v , _dfv , dt , v_params ) # _v = step_rk2(v, v_params, _dfv, dt)
149- w_params = (j_mod , v , alpha , beta , gamma , tau_w )
150- _ , _w = step_rk2 (0. , w , _dfw , dt , w_params ) # _w = step_rk2(w, w_params, _dfw, dt)
140+ @compilable
141+ def advance_state (self , t , dt ):
142+ j_mod = self .j .get () * self .resist_m
143+ if self .intgFlag == 1 :
144+ v_params = (j_mod , self .w .get (), self .alpha , self .beta , self .gamma , self .tau_m )
145+ _ , _v = step_rk2 (0. , self .v .get (), _dfv , dt , v_params ) # _v = step_rk2(v, v_params, _dfv, dt)
146+ w_params = (j_mod , self .v .get (), self .alpha , self .beta , self .gamma , self .tau_w )
147+ _ , _w = step_rk2 (0. , self .w .get (), _dfw , dt , w_params ) # _w = step_rk2(w, w_params, _dfw, dt)
151148 else : # integType == 0 (default -- Euler)
152- v_params = (j_mod , w , alpha , beta , gamma , tau_m )
153- _ , _v = step_euler (0. , v , _dfv , dt , v_params ) # _v = step_euler(v, v_params, _dfv, dt)
154- w_params = (j_mod , v , alpha , beta , gamma , tau_w )
155- _ , _w = step_euler (0. , w , _dfw , dt , w_params ) # _w = step_euler(w, w_params, _dfw, dt)
156- s = (_v > v_thr ) * 1.
149+ v_params = (j_mod , self . w . get (), self . alpha , self . beta , self . gamma , self . tau_m )
150+ _ , _v = step_euler (0. , self . v . get () , _dfv , dt , v_params ) # _v = step_euler(v, v_params, _dfv, dt)
151+ w_params = (j_mod , self . v . get (), self . alpha , self . beta , self . gamma , self . tau_w )
152+ _ , _w = step_euler (0. , self . w . get () , _dfw , dt , w_params ) # _w = step_euler(w, w_params, _dfw, dt)
153+ s = (_v > self . v_thr ) * 1.
157154 v = _v
158155 w = _w
159156
160- if spike_reset : ## if spike-reset used, variables snapped back to initial conditions
161- v = v * (1. - s ) + s * v0
162- w = w * (1. - s ) + s * w0
163- tols = (1. - s ) * tols + (s * t ) ## update tols
164- return j , v , w , s , tols
165-
166- @transition (output_compartments = ["j" , "v" , "w" , "s" , "tols" ])
167- @staticmethod
168- def reset (batch_size , n_units , v0 , w0 ):
169- restVals = jnp .zeros ((batch_size , n_units ))
170- j = restVals # None
171- v = restVals + v0
172- w = restVals + w0
173- s = restVals #+ 0
174- tols = restVals #+ 0
175- return j , v , w , s , tols
157+ if self .spike_reset : ## if spike-reset used, variables snapped back to initial conditions
158+ v = v * (1. - s ) + s * self .v0
159+ w = w * (1. - s ) + s * self .w0
160+
161+ ## update time-of-last spike variable(s)
162+ self .tols .set ((1. - s ) * self .tols .get () + (s * t ))
163+
164+ # self.j.set(j) ## j is not getting modified in these dynamics
165+ self .v .set (v )
166+ self .w .set (w )
167+ self .s .set (s )
168+
169+ @compilable
170+ def reset (self ):
171+ restVals = jnp .zeros ((self .batch_size , self .n_units ))
172+ if not self .j .targeted :
173+ self .j .set (restVals )
174+ self .v .set (restVals + self .v0 )
175+ self .w .set (restVals + self .w0 )
176+ self .s .set (restVals )
177+ self .tols .set (restVals )
176178
177179 @classmethod
178180 def help (cls ): ## component help function
@@ -197,8 +199,7 @@ def help(cls): ## component help function
197199 "resist_m" : "Membrane resistance value" ,
198200 "tau_w" : "Recovery variable time constant" ,
199201 "v_thr" : "Base voltage threshold value" ,
200- "spike_reset" : "Should voltage/recover be snapped to initial "
201- "condition(s) if spike emitted?" ,
202+ "spike_reset" : "Should voltage/recover be snapped to initial condition(s) if spike emitted?" ,
202203 "alpha" : "Dimensionless recovery variable shift factor `a" ,
203204 "beta" : "Dimensionless recovery variable scale factor `b`" ,
204205 "gamma" : "Power-term divisor constant" ,
0 commit comments