22from jax import numpy as jnp , random , jit , nn
33from functools import partial
44from ngclearn .utils import tensorstats
5- from ngcsimlib . deprecators import deprecate_args
5+ from ngcsimlib import deprecate_args
66from ngcsimlib .logger import info , warn
7- from ngclearn .utils .diffeq .ode_utils import get_integrator_code , \
8- step_euler , step_rk2 , step_rk4
7+ from ngclearn .utils .diffeq .ode_utils import get_integrator_code , step_euler , step_rk2 , step_rk4
98
10- from ngcsimlib .compilers .process import transition
11- #from ngcsimlib.component import Component
9+ from ngcsimlib .parser import compilable
1210from ngcsimlib .compartment import Compartment
1311
1412
@@ -113,7 +111,6 @@ class HodgkinHuxleyCell(JaxComponent): ## Hodgkin-Huxley spiking cell
113111 at an increase in computational cost (and simulation time)
114112 """
115113
116- # Define Functions
117114 def __init__ (
118115 self , name , n_units , tau_v , resist_m = 1. , v_Na = 115. , v_K = - 35. , v_L = 10.6 , g_Na = 100. , g_K = 5. , g_L = 0.3 , thr = 4. ,
119116 spike_reset = False , v_reset = 0. , integration_type = "euler" , ** kwargs
@@ -126,7 +123,7 @@ def __init__(
126123
127124 ## cell properties / biophysical parameter setup (affects ODE integration)
128125 self .tau_v = tau_v ## membrane time constant
129- self .R_m = resist_m ## resistance value
126+ self .resist_m = resist_m ## resistance value R_m
130127 self .spike_reset = spike_reset
131128 self .thr = thr # mV ## base value for threshold
132129 self .v_reset = v_reset ## base value to reset voltage to (if spike_reset = True)
@@ -151,38 +148,49 @@ def __init__(
151148 self .s = Compartment (restVals , display_name = "Spike pulse" )
152149 self .tols = Compartment (restVals , display_name = "Time-of-last-spike" ) ## time-of-last-spike
153150
154- @transition (output_compartments = ["v" , "m" , "n" , "h" , "s" , "tols" ])
155- @staticmethod
156- def advance_state (
157- t , dt , spike_reset , v_reset , thr , tau_v , R_m , g_Na , g_K , g_L , v_Na , v_K , v_L , j , v , m , n , h , tols , intgFlag
158- ):
159- _j = j * R_m
160- alpha_n_of_v , beta_n_of_v , alpha_m_of_v , beta_m_of_v , alpha_h_of_v , beta_h_of_v = _calc_biophysical_constants (v )
151+ #@transition(output_compartments=["v", "m", "n", "h", "s", "tols"])
152+ #@staticmethod
153+ @compilable
154+ def advance_state (self , t , dt ): #t, dt, spike_reset, v_reset, thr, tau_v, R_m, g_Na, g_K, g_L, v_Na, v_K, v_L, j, v, m, n, h, tols, intgFlag
155+ _j = self .j .get () * self .resist_m
156+ alpha_n_of_v , beta_n_of_v , alpha_m_of_v , beta_m_of_v , alpha_h_of_v , beta_h_of_v = _calc_biophysical_constants (self .v .get ())
161157 ## integrate voltage / membrane potential
162- if intgFlag == 1 : ## midpoint method
163- _ , _v = step_rk2 (0. , v , dv_dt , dt , (_j , m + 0. , n + 0. , h + 0. , tau_v , g_Na , g_K , g_L , v_Na , v_K , v_L ))
158+ if self .intgFlag == 1 : ## midpoint method
159+ _ , _v = step_rk2 (
160+ 0. , self .v .get (), dv_dt , dt ,
161+ (_j , self .m .get () + 0. , self .n .get () + 0. , self .h .get () + 0. , self .tau_v , self .g_Na , self .g_K ,
162+ self .g_L , self .v_Na , self .v_K , self .v_L )
163+ )
164164 ## next, integrate different channels
165- _ , _n = step_rk2 (0. , n , dx_dt , dt , (alpha_n_of_v , beta_n_of_v ))
166- _ , _m = step_rk2 (0. , m , dx_dt , dt , (alpha_m_of_v , beta_m_of_v ))
167- _ , _h = step_rk2 (0. , h , dx_dt , dt , (alpha_h_of_v , beta_h_of_v ))
168- elif intgFlag == 4 : ## Runge-Kutta 4th order
169- _ , _v = step_rk4 (0. , v , dv_dt , dt , (_j , m + 0. , n + 0. , h + 0. , tau_v , g_Na , g_K , g_L , v_Na , v_K , v_L ))
165+ _ , _n = step_rk2 (0. , self .n .get (), dx_dt , dt , (alpha_n_of_v , beta_n_of_v ))
166+ _ , _m = step_rk2 (0. , self .m .get (), dx_dt , dt , (alpha_m_of_v , beta_m_of_v ))
167+ _ , _h = step_rk2 (0. , self .h .get (), dx_dt , dt , (alpha_h_of_v , beta_h_of_v ))
168+ elif self .intgFlag == 4 : ## Runge-Kutta 4th order
169+ _ , _v = step_rk4 (
170+ 0. , self .v .get (), dv_dt , dt ,
171+ (_j , self .m .get () + 0. , self .n .get () + 0. , self .h .get () + 0. , self .tau_v , self .g_Na , self .g_K ,
172+ self .g_L , self .v_Na , self .v_K , self .v_L )
173+ )
170174 ## next, integrate different channels
171- _ , _n = step_rk4 (0. , n , dx_dt , dt , (alpha_n_of_v , beta_n_of_v ))
172- _ , _m = step_rk4 (0. , m , dx_dt , dt , (alpha_m_of_v , beta_m_of_v ))
173- _ , _h = step_rk4 (0. , h , dx_dt , dt , (alpha_h_of_v , beta_h_of_v ))
175+ _ , _n = step_rk4 (0. , self . n . get () , dx_dt , dt , (alpha_n_of_v , beta_n_of_v ))
176+ _ , _m = step_rk4 (0. , self . m . get () , dx_dt , dt , (alpha_m_of_v , beta_m_of_v ))
177+ _ , _h = step_rk4 (0. , self . h . get () , dx_dt , dt , (alpha_h_of_v , beta_h_of_v ))
174178 else : # integType == 0 (default -- Euler)
175- _ , _v = step_euler (0. , v , dv_dt , dt , (_j , m + 0. , n + 0. , h + 0. , tau_v , g_Na , g_K , g_L , v_Na , v_K , v_L ))
179+ _ , _v = step_euler (
180+ 0. , self .v .get (), dv_dt , dt ,
181+ (_j , self .m .get () + 0. , self .n .get () + 0. , self .h .get () + 0. , self .tau_v , self .g_Na , self .g_K ,
182+ self .g_L , self .v_Na , self .v_K , self .v_L )
183+ )
176184 ## next, integrate different channels
177- _ , _n = step_euler (0. , n , dx_dt , dt , (alpha_n_of_v , beta_n_of_v ))
178- _ , _m = step_euler (0. , m , dx_dt , dt , (alpha_m_of_v , beta_m_of_v ))
179- _ , _h = step_euler (0. , h , dx_dt , dt , (alpha_h_of_v , beta_h_of_v ))
185+ _ , _n = step_euler (0. , self . n . get () , dx_dt , dt , (alpha_n_of_v , beta_n_of_v ))
186+ _ , _m = step_euler (0. , self . m . get () , dx_dt , dt , (alpha_m_of_v , beta_m_of_v ))
187+ _ , _h = step_euler (0. , self . h . get () , dx_dt , dt , (alpha_h_of_v , beta_h_of_v ))
180188 ## obtain action potentials/spikes/pulses
181- s = (_v > thr ) * 1.
182- if spike_reset : ## if spike-reset used, variables snapped back to initial conditions
189+ s = (_v > self . thr ) * 1.
190+ if self . spike_reset : ## if spike-reset used, variables snapped back to initial conditions
183191 alpha_n_of_v , beta_n_of_v , alpha_m_of_v , beta_m_of_v , alpha_h_of_v , beta_h_of_v = (
184- _calc_biophysical_constants (v * 0 + v_reset ))
185- _v = _v * (1. - s ) + s * v_reset
192+ _calc_biophysical_constants (self . v . get () * 0 + self . v_reset ))
193+ _v = _v * (1. - s ) + s * self . v_reset
186194 _n = _n * (1. - s ) + s * (alpha_n_of_v / (alpha_n_of_v + beta_n_of_v ))
187195 _m = _m * (1. - s ) + s * (alpha_m_of_v / (alpha_m_of_v + beta_m_of_v ))
188196 _h = _h * (1. - s ) + s * (alpha_h_of_v / (alpha_h_of_v + beta_h_of_v ))
@@ -191,32 +199,40 @@ def advance_state(
191199 m = _m
192200 n = _n
193201 h = _h
194- tols = (1. - s ) * tols + (s * t ) ## update tols
202+ ## update time-of-last spike variable(s)
203+ self .tols .set ((1. - s ) * self .tols .get () + (s * t ))
195204
196- return v , m , n , h , s , tols
205+ self .v .set (v )
206+ self .m .set (m )
207+ self .n .set (n )
208+ self .h .set (h )
209+ self .s .set (s )
197210
198- @transition (output_compartments = ["j" , "v" , "m" , "n" , "h" , "s" , "tols" ])
199- @staticmethod
200- def reset (batch_size , n_units ):
201- restVals = jnp .zeros ((batch_size , n_units ))
211+ @compilable
212+ def reset (self ):
213+ restVals = jnp .zeros ((self .batch_size , self .n_units ))
202214 v = restVals # + 0
203215 alpha_n_of_v , beta_n_of_v , alpha_m_of_v , beta_m_of_v , alpha_h_of_v , beta_h_of_v = _calc_biophysical_constants (v )
204- j = restVals #+ 0
216+ if not self .j .targeted :
217+ self .j .set (restVals )
205218 n = alpha_n_of_v / (alpha_n_of_v + beta_n_of_v )
206219 m = alpha_m_of_v / (alpha_m_of_v + beta_m_of_v )
207220 h = alpha_h_of_v / (alpha_h_of_v + beta_h_of_v )
208- s = restVals #+ 0
209- tols = restVals #+ 0
210- return j , v , m , n , h , s , tols
211-
212- def save (self , directory , ** kwargs ):
213- file_name = directory + "/" + self .name + ".npz"
214- #jnp.savez(file_name, threshold=self.thr.value)
215-
216- def load (self , directory , seeded = False , ** kwargs ):
217- file_name = directory + "/" + self .name + ".npz"
218- data = jnp .load (file_name )
219- #self.thr.set( data['threshold'] )
221+ self .v .set (v )
222+ self .n .set (n )
223+ self .m .set (m )
224+ self .h .set (h )
225+ self .s .set (restVals )
226+ self .tols .set (restVals )
227+
228+ # def save(self, directory, **kwargs):
229+ # file_name = directory + "/" + self.name + ".npz"
230+ # #jnp.savez(file_name, threshold=self.thr.value)
231+ #
232+ # def load(self, directory, seeded=False, **kwargs):
233+ # file_name = directory + "/" + self.name + ".npz"
234+ # data = jnp.load(file_name)
235+ # #self.thr.set( data['threshold'] )
220236
221237 @classmethod
222238 def help (cls ): ## component help function
@@ -258,7 +274,7 @@ def help(cls): ## component help function
258274 return info
259275
260276 def __repr__ (self ):
261- comps = [varname for varname in dir (self ) if Compartment . is_compartment (getattr (self , varname ))]
277+ comps = [varname for varname in dir (self ) if isinstance (getattr (self , varname ), Compartment )]
262278 maxlen = max (len (c ) for c in comps ) + 5
263279 lines = f"[{ self .__class__ .__name__ } ] PATH: { self .name } \n "
264280 for c in comps :
0 commit comments